def _split_bias_params(module):
    """Split ``module``'s parameters into ``(bias_params, other_params)`` lists.

    A parameter is treated as a bias when its qualified name ends in
    ``'bias'``. This lets the SGD setup give biases their own learning rate
    and exempt them from weight decay.
    """
    biases, weights = [], []
    for name, param in module.named_parameters():
        (biases if name.endswith('bias') else weights).append(param)
    return biases, weights


def main():
    """Train the backbone + ``Deconv`` decoder and checkpoint the best model.

    Configuration is read from the module-level ``opt`` namespace:

    * ``opt.train_dir`` / ``opt.val_dir`` -- dataset roots passed to ``MyData``.
    * ``opt.check_dir`` -- directory receiving ``deconv.pth`` and ``feature.pth``.
    * ``opt.b`` -- training batch size (validation uses half of it, at least 1).
    * ``opt.e`` -- number of passes over the training set.
    * ``opt.q`` -- backbone name: ``'vgg'`` or a name containing ``'resnet'``
      or ``'densenet'`` (looked up as an attribute of that module).

    After every pass the models are scored with ``validation``. Both state
    dicts are saved whenever that score is lower than the best seen so far.

    Raises:
        ValueError: if ``opt.q`` does not name a supported backbone.
    """
    train_dir = opt.train_dir
    val_dir = opt.val_dir
    check_dir = opt.check_dir
    bsize = opt.b
    iter_num = opt.e  # number of epochs (full passes over train_loader)

    if not os.path.exists(check_dir):
        os.makedirs(check_dir)

    # models
    if opt.q == 'vgg':
        feature = vgg.vgg(pretrained=True)
    elif 'resnet' in opt.q:
        feature = getattr(resnet, opt.q)(pretrained=True)
    elif 'densenet' in opt.q:
        feature = getattr(densenet, opt.q)(pretrained=True)
    else:
        # Fail loudly instead of falling through with feature = None and
        # crashing later on feature.cuda().
        raise ValueError('unsupported backbone %r' % (opt.q,))
    feature.cuda()

    deconv = Deconv(opt.q)
    deconv.cuda()

    train_loader = torch.utils.data.DataLoader(
        MyData(train_dir, transform=True, crop=False, hflip=False, vflip=False),
        batch_size=bsize, shuffle=True, num_workers=4, pin_memory=True)
    # DataLoader needs an int batch size: `bsize / 2` is a float on Python 3.
    # Clamp to 1 so that bsize == 1 does not produce an empty batch size.
    val_loader = torch.utils.data.DataLoader(
        MyData(val_dir, transform=True, crop=False, hflip=False, vflip=False),
        batch_size=max(1, bsize // 2), shuffle=True, num_workers=4,
        pin_memory=True)

    if 'resnet' in opt.q:
        lr = 5e-3
        lr_decay = 0.9
        deconv_bias, deconv_weight = _split_bias_params(deconv)
        feature_bias, feature_weight = _split_bias_params(feature)
        # Group order matters: the per-epoch schedule below indexes
        # param_groups 0..3 as (bias, weight, bias, weight).
        optimizer = torch.optim.SGD([
            {'params': deconv_bias, 'lr': 2 * lr},
            {'params': deconv_weight, 'lr': lr, 'weight_decay': 1e-4},
            {'params': feature_bias, 'lr': 2 * lr},
            {'params': feature_weight, 'lr': lr, 'weight_decay': 1e-4},
        ], momentum=0.9, nesterov=True)
    else:
        optimizer = torch.optim.Adam([
            {'params': feature.parameters(), 'lr': 1e-4},
            {'params': deconv.parameters(), 'lr': 1e-3},
        ])

    min_loss = 10000.0
    for it in range(iter_num):
        # validation() is defined elsewhere and may switch the models to eval
        # mode; re-enable training behaviour (BatchNorm/Dropout) every epoch.
        feature.train()
        deconv.train()

        if 'resnet' in opt.q:
            # Polynomial decay: lr * (1 - it / iter_num) ** lr_decay,
            # with biases at twice the weight learning rate.
            scale = (1 - float(it) / iter_num) ** lr_decay
            optimizer.param_groups[0]['lr'] = 2 * lr * scale  # bias
            optimizer.param_groups[1]['lr'] = lr * scale  # weight
            optimizer.param_groups[2]['lr'] = 2 * lr * scale  # bias
            optimizer.param_groups[3]['lr'] = lr * scale  # weight

        for ib, (data, lbl) in enumerate(train_loader):
            inputs = Variable(data).cuda()
            lbl = Variable(lbl.float().unsqueeze(1)).cuda()
            feats = feature(inputs)
            msk = deconv(feats)
            loss = F.binary_cross_entropy_with_logits(msk, lbl)

            deconv.zero_grad()
            feature.zero_grad()
            loss.backward()
            optimizer.step()

            # loss.data[0] raises IndexError on 0-dim tensors in modern
            # PyTorch; .item() is the supported way to read a scalar.
            print('loss: %.4f (epoch: %d, step: %d)' % (loss.item(), it, ib))
            # Drop references to this step's graph before the next batch.
            del inputs, msk, lbl, loss, feats
            gc.collect()

        sb = validation(feature, deconv, val_loader)
        if sb < min_loss:
            torch.save(deconv.state_dict(),
                       os.path.join(check_dir, 'deconv.pth'))
            torch.save(feature.state_dict(),
                       os.path.join(check_dir, 'feature.pth'))
            print('save: (epoch: %d)' % it)
            min_loss = sb
for ib, (data, _, lbl) in enumerate(train_loader): inputs = Variable(data) # inputs = Variable(data).cuda() # lbl = Variable(lbl.unsqueeze(1)).cuda() lbl = Variable(lbl.unsqueeze(1)) loss = 0 feats = feature(inputs) feats = feats[-3:] feats = feats[::-1] msk = deconv(feats) msk = functional.upsample(msk, scale_factor=4) prior = functional.sigmoid(msk) loss += criterion(msk, lbl) deconv.zero_grad() feature.zero_grad() loss.backward() optimizer_feature.step() optimizer_deconv.step() # visulize image = make_image_grid(inputs.data[:, :3], mean, std) writer.add_image('Image', torchvision.utils.make_grid(image), ib) msk = functional.sigmoid(msk) mask1 = msk.data # mskdata,分割出来的。 mask1 = mask1.repeat(1, 3, 1, 1) acc = math.e**(0 - loss) writer.add_image('Image2', torchvision.utils.make_grid(mask1), ib)