FixMatch (FixMatch Reproduction)

paper:
code:

FixMatch, proposed by Sohn et al., combines pseudo-labeling and consistency regularization and greatly simplifies the semi-supervised pipeline. It achieves state-of-the-art results across a wide range of benchmarks.

The idea is as follows. A supervised model is first trained on the labeled images with a cross-entropy loss. For each unlabeled image, a weak augmentation and a strong augmentation are applied, producing two views. The weakly augmented view is passed through the model to obtain a prediction, and the probability of the most confident class is compared against a threshold. If it exceeds the threshold, that class is taken as the label, i.e., the pseudo-label. The strongly augmented view is then passed through the model, and its prediction is compared against the pseudo-label with a cross-entropy loss. The two losses are combined to optimize the model.
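To make the procedure concrete, here is a minimal sketch of the FixMatch unlabeled-data loss for a classification model. It is not code from this repository: model, weak_batch, strong_batch, threshold and lambda_u are placeholder names, and the supervised cross-entropy term is assumed to be computed separately.

import torch
import torch.nn.functional as F

def fixmatch_unlabeled_loss(model, weak_batch, strong_batch, threshold=0.95):
    # Pseudo-label from the weakly augmented view; no gradient flows through it.
    with torch.no_grad():
        probs = torch.softmax(model(weak_batch), dim=1)
        max_probs, pseudo_labels = probs.max(dim=1)
        mask = (max_probs >= threshold).float()   # keep only confident predictions
    # Cross-entropy between the strong-view prediction and the pseudo-label,
    # masked so that low-confidence samples contribute nothing.
    logits_s = model(strong_batch)
    per_sample = F.cross_entropy(logits_s, pseudo_labels, reduction='none')
    return (per_sample * mask).mean()

# Total loss (lambda_u weights the unlabeled term):
# loss = F.cross_entropy(model(labeled_batch), labels) \
#        + lambda_u * fixmatch_unlabeled_loss(model, weak_batch, strong_batch)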

Code

train.py

class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        ########## hyper parameters setting #################
        self.net = get_model(cfg.num_classes, cfg.model_name).to(device)
        optimizer = RAdam(params=self.net.parameters(), lr=cfg.lr, weight_decay=0.0001)
        self.optimizer = Lookahead(optimizer)
        milestones = [5 + x * 60 for x in range(5)]
        # print(f'milestones:{milestones}')
        scheduler_c = CyclicCosAnnealingLR(optimizer, milestones=milestones, eta_min=5e-5)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer, target_iteration=5,
                                            target_lr=0.003, after_scheduler=scheduler_c)
        self.criterion = ComboLoss().to(device)
        self.G = GridMask(True, True)
        self.best_acc = -100

    def load_net(self, path):
        self.net = torch.load(path, map_location='cuda:0')["model_state"]
        # self.best_acc = torch.load(path, map_location='cuda:0')["best_acc"]
        # print(f'best_acc: {self.best_acc}')

    def train_one_epoch(self, loader):
        num_samples = 0
        running_loss = 0
        trn_error = 0
        self.net.train()
        for images, masks in loader:
            if self.cfg.cutMix:
                images, masks = cutmix(images, masks)
            if self.cfg.fmix:
                w, h = images.size(-1), images.size(-2)
                images, masks = fmix_seg(images, masks, alpha=1., decay_power=3., shape=(w, h))
            images = images.to(device, dtype=torch.float)
            if self.cfg.Grid:
                images = self.G(images)
            masks = torch.squeeze(masks.to(device))
            # print("images'size:{},masks'size:{}".format(images.size(), masks.size()))
            num_samples += int(images.size(0))
            self.optimizer.zero_grad()
            outputs, cls = self.net(images)
            loss = self.criterion(outputs, masks, cls)
            loss.backward()
            batch_loss = loss.item()
            self.optimizer.step()
            running_loss += batch_loss
            pred = get_predictions(outputs)
            masks = masks.type(torch.cuda.LongTensor)
            masks = masks.data.cpu()
            trn_error += compute_error(pred, masks)
        return running_loss / len(loader), trn_error / len(loader)

    def validate(self, loader):
        num_samples = 0
        running_loss = 0
        trn_error = 0
        self.net.eval()
        for images, masks in loader:
            images = images.to(device, dtype=torch.float)
            masks = torch.squeeze(masks.to(device))
            num_samples += int(images.size(0))
            outputs, cls = self.net(images)
            loss = self.criterion(outputs, masks, cls)
            batch_loss = loss.item()
            running_loss += batch_loss
            pred = get_predictions(outputs)
            masks = masks.type(torch.cuda.LongTensor)
            masks = masks.data.cpu()
            trn_error += compute_error(pred, masks)
        return running_loss / len(loader), trn_error / len(loader)

    def train(self):
        mkdir(self.cfg.model_save_path)
        ########## prepare dataset ################################
        train_loader, val_loader, test_loader = build_loader(self.cfg)
        for epoch in range(self.cfg.num_epochs):
            print("Epoch: {}/{}".format(epoch + 1, self.cfg.num_epochs))
            # optimizer.step()
            self.scheduler.step(epoch)
            #################### train ####################################
            train_loss, train_error = self.train_one_epoch(train_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Training Loss: {train_loss:.4f}.. ",
                f"Training Acc: {1 - train_error:.4f}.. ",
            )
            ###################### valid ##################################
            val_loss, val_error = self.validate(val_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"validation Loss: {val_loss:.4f}.. ",
                f"validation Acc: {1 - val_error:.4f}.. ",
            )
            if 1 - val_error > self.best_acc:
                state = {
                    "epoch": epoch + 1,
                    "model_state": self.net,
                    "best_acc": 1 - val_error
                }
                checkpoint = f'{self.cfg.model_name}_best.pth'
                torch.save(state, os.path.join(self.cfg.model_save_path, checkpoint))  # save model
                print("The model has saved successfully!")
                self.best_acc = 1 - val_error

    def train_one_epoch_semi(self, trainloader, testloader):
        running_loss = 0
        trn_error = 0
        loader = zip(trainloader, testloader)
        self.net.train()
        for data_x, data_u in loader:
            images_x, targets_x = data_x
            images_u_w, images_u_s = data_u
            # cpu ==> gpu
            images_x = images_x.to(device, dtype=torch.float)
            targets_x = torch.squeeze(targets_x.to(device))
            images_u_w = images_u_w.to(device, dtype=torch.float)
            images_u_s = images_u_s.to(device, dtype=torch.float)
            if self.cfg.Grid:
                images_x = self.G(images_x)
                images_u_s = self.G(images_u_s)
            # print("images'size:{},masks'size:{}".format(images.size(), masks.size()))
            self.optimizer.zero_grad()
            outputs_x, cls_x = self.net(images_x)
            outputs_u_w, cls_u_w = self.net(images_u_w)
            outputs_u_s, cls_u_s = self.net(images_u_s)
            # get pseudo label
            targets_u = outputs_u_w.ge(self.cfg.threshold).float()
            loss_x = self.criterion(outputs_x, targets_x, cls_x)
            loss_u = (self.criterion(outputs_u_s, torch.squeeze(targets_u), cls_x, reduction='none')
                      * torch.squeeze(targets_u)).mean()
            loss = loss_x + self.cfg.lambda_u * loss_u
            loss.backward()
            batch_loss = loss.item()
            self.optimizer.step()
            running_loss += batch_loss
            pred = get_predictions(outputs_x)
            masks = targets_x.type(torch.cuda.LongTensor)
            masks = masks.data.cpu()
            trn_error += compute_error(pred, masks)
        return running_loss / len(trainloader), trn_error / len(trainloader)

    def train_semi(self):
        self.load_net(f'{self.cfg.model_save_path}/{self.cfg.model_name}_best.pth')
        model_save_path = self.cfg.model_save_path + '_semi'
        mkdir(model_save_path)
        ########## prepare dataset ################################
        train_loader, val_loader, test_loader = build_loader_v2(self.cfg)
        for epoch in range(self.cfg.num_epochs):
            print("Epoch: {}/{}".format(epoch + 1, self.cfg.num_epochs))
            # optimizer.step()
            self.scheduler.step(epoch)
            #################### train ####################################
            train_loss, train_error = self.train_one_epoch_semi(train_loader, test_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Training Loss: {train_loss:.4f}.. ",
                f"Training Acc: {1 - train_error:.4f}.. ",
            )
            ###################### valid ##################################
            val_loss, val_error = self.validate(val_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"validation Loss: {val_loss:.4f}.. ",
                f"validation Acc: {1 - val_error:.4f}.. ",
            )
            if 1 - val_error > self.best_acc:
                state = {
                    "epoch": epoch + 1,
                    "model_state": self.net,
                    "best_acc": 1 - val_error
                }
                checkpoint = f'{self.cfg.model_name}_best.pth'
                torch.save(state, os.path.join(model_save_path, checkpoint))  # save model
                print("The model has saved successfully!")
                self.best_acc = 1 - val_error

dataset.py

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torchvision
from torchvision.transforms import Compose
import numpy as np
import cv2 as cv
import os
from random import sample
from utils.transforms import *
from utils.randaugment import *
from utils.grid import Grid


def img_to_tensor(img):
    tensor = torch.from_numpy(img.transpose((2, 0, 1)))
    return tensor


def to_monochrome(x):
    # x_ = x.convert('L')
    x_ = np.array(x).astype(np.float32)  # convert image to monochrome
    return x_


def to_tensor(x):
    x_ = np.expand_dims(x, axis=0)
    x_ = torch.from_numpy(x_)
    return x_


ImageToTensor = torchvision.transforms.ToTensor


def custom_blur_demo(image):
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32)  # sharpening kernel
    dst = cv.filter2D(image, -1, kernel=kernel)
    return dst


class SasDataset(Dataset):
    def __init__(self, root, mode='train', is_ndvi=False):
        self.root = root
        self.mode = mode
        self.is_ndvi = is_ndvi
        self.imgList = sorted(img for img in os.listdir(self.root))
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
            CoarseDropout()
        ])
        self.RA = RandomAugment(2, 10)
        self.imgTransforms = Compose([img_to_tensor])
        self.maskTransforms = Compose([
            torchvision.transforms.Lambda(to_monochrome),
            torchvision.transforms.Lambda(to_tensor),
        ])

    def __getitem__(self, idx):
        imgPath = os.path.join(self.root, self.imgList[idx])
        img = np.load(imgPath)
        img = custom_blur_demo(img)
        imgName = os.path.split(imgPath)[-1].split('.')[0]
        if self.mode == 'test':
            batch_data = {'img': self.imgTransforms(img), 'file_name': imgName}
            return batch_data
        labelPath = imgPath.replace('images', 'labels').replace('npy', 'png')
        mask = cv.imread(labelPath) / 255
        # data augmentation
        if self.mode == 'train':
            img, mask = self.transform(img, mask)
            img = self.RA(img)
            # img, mask = img.astype(np.float), mask.astype(np.float)
        w, h = mask.shape[:2]
        mask = mask[:, :, 0]
        mask = np.reshape(mask, (w, h, 1)).transpose((2, 0, 1))
        return self.imgTransforms(img), self.maskTransforms(np.squeeze(mask))

    def __len__(self):
        return len(self.imgList)


class USasDataset(Dataset):
    def __init__(self, root, mode='train'):
        self.root = root
        self.mode = mode
        self.imgList = sorted(img for img in os.listdir(self.root))
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
            # Cutout(num_holes=20, max_h_size=20, max_w_size=20, fill_value=0)
        ])
        self.RA = RandomAugment(2, 10)
        self.imgTransforms = Compose([ImageToTensor()])

    def __getitem__(self, idx):
        imgPath = os.path.join(self.root, self.imgList[idx])
        img = np.load(imgPath)
        img = custom_blur_demo(img)
        mask = np.zeros_like(img)
        # weak data augmentation
        img_w, _ = self.transform(img, mask)
        # strong data augmentation
        img_s = self.RA(img_w)
        return self.imgTransforms(img_w), self.imgTransforms(img_s)

    def __len__(self):
        return len(self.imgList)
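USasDataset returns a (weak view, strong view) pair for each unlabeled image: the geometric DualCompose pipeline serves as the weak augmentation, and RandomAugment applied on top of it serves as the strong augmentation. The project's build_loader / build_loader_v2 helpers are not shown in this post, so the pairing below is only an illustrative sketch; the dataset roots and batch sizes are assumptions, not values from the repository.

from torch.utils.data import DataLoader

# hypothetical paths and batch sizes, for illustration only
labeled_ds = SasDataset(root='data/train/images', mode='train')
unlabeled_ds = USasDataset(root='data/unlabeled/images', mode='train')

labeled_loader = DataLoader(labeled_ds, batch_size=8, shuffle=True, drop_last=True)
unlabeled_loader = DataLoader(unlabeled_ds, batch_size=8, shuffle=True, drop_last=True)

# Trainer.train_one_epoch_semi zips the two loaders, so an epoch ends when the
# shorter of the two loaders is exhausted.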

