def build_model(self):
    self.net = build_model(self.config.arch)
    # move the network to CUDA if requested
    if self.config.cuda:
        self.net = self.net.cuda()
    # self.net.train()
    # set eval mode
    self.net.eval()  # use_global_stats = True
    # initialize the network weights
    self.net.apply(weights_init)
    # load either the pretrained backbone or a previously trained model
    if self.config.load == '':
        self.net.base.load_pretrained_model(torch.load(self.config.pretrained_model))
    else:
        self.net.load_state_dict(torch.load(self.config.load))
    # learning rate
    self.lr = self.config.lr
    # weight decay
    self.wd = self.config.wd
    # set up the optimizer
    self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()),
                          lr=self.lr, weight_decay=self.wd)
    # print the network structure
    self.print_network(self.net, 'PoolNet Structure')
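# Note: `weights_init` is applied via `net.apply(...)` throughout these snippets but is not
# defined here. The sketch below is an assumed, typical PoolNet-style initializer, not the
# project's actual implementation.
def weights_init(m):
    # initialize conv layers with a small Gaussian and zero biases
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            m.bias.data.zero_()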
def build_model(self):
    net = build_model(self.arch)
    if torch.cuda.is_available():
        net = net.cuda()
    # BN layers have a `use_global_stats` flag that controls whether Caffe's stored running
    # mean and variance are used. Set it to False when training and True when testing;
    # otherwise training may produce NaNs or fail to converge.
    net.eval()  # use_global_stats = True
    net.apply(self.weights_init)
    if self.pretrained_model:
        net.base.load_pretrained_model(torch.load(self.pretrained_model))
        pass
    self._print_network(net, 'PoolNet Structure')
    return net
def main():
    # create the model
    model = build_model()
    model.to(device)
    model.load_state_dict(torch.load(args.restore_from))

    # create the discriminator
    model_D1 = FCDiscriminator(num_classes=1)
    model_D1.to(device)
    model_D1.load_state_dict(torch.load(args.D_restore_from))

    up = torch.nn.Upsample(scale_factor=32, mode='bilinear')
    sig = torch.nn.Sigmoid()

    # labels for adversarial training (markers for the two domains)
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    correct = 0
    tot = 0
    for i_iter, data_batch in enumerate(picloader):
        tot += 2
        sal_image, edge_image = data_batch['sal_image'], data_batch['edge_image']
        sal_image, edge_image = Variable(sal_image), Variable(edge_image)
        sal_image, edge_image = sal_image.to(device), edge_image.to(device)

        sal_pred = model(sal_image)
        edge_pred = model(edge_image)

        # test D
        # for param in model_D1.parameters():
        #     param.requires_grad = True
        ss_out = model_D1(sal_pred)
        se_out = model_D1(edge_pred)
        if pan(ss_out) == salLabel:
            correct += 1
        if pan(se_out) == edgeLabel:
            correct += 1
        if i_iter % 100 == 0:
            print('processing %d: %f' % (i_iter, correct / tot))
    print(correct / tot)
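# Note: `pan` is used above but not defined in these snippets. The helper below is an
# assumed minimal implementation consistent with that usage: it averages the discriminator's
# sigmoid response over the output map and thresholds at 0.5 to produce a hard domain label
# (0 = saliency domain, 1 = edge domain).
def pan(d_out, threshold=0.5):
    # mean sigmoid probability over the discriminator's output map
    prob = torch.sigmoid(d_out).mean().item()
    return 1 if prob >= threshold else 0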
def main():
    gtdir = args.snapshot_dir + 'gt/1/'
    preddir = args.snapshot_dir + 'pred/tsy1/1/'
    # make dirs
    if not os.path.exists(gtdir):
        os.makedirs(gtdir)
    if not os.path.exists(preddir):
        os.makedirs(preddir)

    # cuDNN autotuning ("black-magic" speed-up)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    picloader = get_loader(args, mode='test')
    for i_iter, data_batch in enumerate(picloader):
        if i_iter % 50 == 0:
            print(i_iter)
        sal_image, sal_label = data_batch['sal_image'], data_batch['sal_label']
        with torch.no_grad():
            sal_image = Variable(sal_image).to(device)
            preds = model(sal_image)
            pred = np.squeeze(torch.sigmoid(preds).cpu().data.numpy())
            label = np.squeeze(sal_label.cpu().data.numpy())
            multi_fuse = 255 * pred
            label = 255 * label
            cv2.imwrite(os.path.join(preddir, str(i_iter) + '.jpg'), multi_fuse)
            cv2.imwrite(os.path.join(gtdir, str(i_iter) + '.png'), label)
def test(arch, model_path, test_loader, result_fold):
    Tools.print('Loading trained model from {}'.format(model_path))
    net = build_model(arch).cuda()
    net.load_state_dict(torch.load(model_path))
    net.eval()

    time_s = time.time()
    img_num = len(test_loader)
    for i, data_batch in enumerate(test_loader):
        if i % 100 == 0:
            Tools.print("test {} {}".format(i, img_num))
        images, name, im_size = data_batch['image'], data_batch['name'][0], np.asarray(data_batch['size'])
        with torch.no_grad():
            images = torch.Tensor(images).cuda()
            pred = net(images)
            pred = np.squeeze(torch.sigmoid(pred).cpu().data.numpy()) * 255
            cv2.imwrite(os.path.join(result_fold, name[:-4] + '.png'), pred)
    time_e = time.time()
    Tools.print('Speed: %f FPS' % (img_num / (time_e - time_s)))
    Tools.print('Test Done!')
    pass
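# Hedged usage sketch: the arch string and paths below are illustrative placeholders, and
# the loader is assumed to yield dicts with 'image', 'name', and 'size' keys, which is what
# `test` consumes above.
if __name__ == '__main__':
    result_fold = './results/pred'
    if not os.path.exists(result_fold):
        os.makedirs(result_fold)
    test_loader = get_loader(args, mode='test')  # assumption: the project's test loader
    test('resnet', './results/run-0/models/final.pth', test_loader, result_fold)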
def main():
    # make dirs
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # cuDNN autotuning ("black-magic" speed-up)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))
    # model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # create the discriminator
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D1.apply(weights_init)
    # model_D1.load_state_dict(torch.load(args.D_restore_from))

    # create the optimizers
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)  # optimizer for the whole model
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # not strictly necessary
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # log the start time and hyper-parameters
    with open(args.file_dir, 'a') as f:
        f.write('start time: ' + str(datetime.now()) + '\n\n')
        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('weight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')
        f.write('epoch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # labels for adversarial training (markers for the two domains)
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    aveGrad = 0
    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):
            sal_image, sal_label, edge_image = data_batch['sal_image'], data_batch['sal_label'], data_batch['edge_image']
            if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING\n')
                continue
            sal_image, sal_label, edge_image = Variable(sal_image), Variable(sal_label), Variable(edge_image)
            sal_image, sal_label, edge_image = sal_image.to(device), sal_label.to(device), edge_image.to(device)

            sal_pred = model(sal_image)
            edge_pred = model(edge_image)

            # train G (with D frozen)
            for param in model_D1.parameters():
                param.requires_grad = False
            sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data
            sal_loss.backward()

            sD_out = model_D1(edge_pred)
            # BCE loss: when training G, the loss is low if the edge (target) prediction is
            # classified as the saliency (source) label. The second argument is a tensor of
            # the desired labels with the same size as the discriminator output.
            loss_adv_target1 = bce_loss(
                sD_out,
                torch.FloatTensor(sD_out.data.size()).fill_(salLabel).to(device))
            sd_loss = loss_adv_target1 / (args.iter_size * args.batch_size)
            loss_adv_target_value1 += sd_loss.data  # for logging only
            sd_loss = sd_loss * args.lambda_adv_target2
            sd_loss.backward()

            # train D
            for param in model_D1.parameters():
                param.requires_grad = True
            sal_pred = sal_pred.detach()
            edge_pred = edge_pred.detach()

            ss_out = model_D1(sal_pred)
            ss_loss = bce_loss(
                ss_out,
                torch.FloatTensor(ss_out.data.size()).fill_(salLabel).to(device))
            ss_Loss = ss_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += ss_Loss.data
            ss_Loss.backward()

            se_out = model_D1(edge_pred)
            se_loss = bce_loss(
                se_out,
                torch.FloatTensor(se_out.data.size()).fill_(edgeLabel).to(device))
            se_Loss = se_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += se_Loss.data
            se_Loss.backward()

            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                optimizer_D1.step()
                optimizer_D1.zero_grad()
                aveGrad = 0

            if i_iter % (args.show_every // args.batch_size) == 0:
                print('epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                      .format(i_iter, iter_num, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size))
                with open(args.file_dir, 'a') as f:
                    f.write('epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}\n'
                            .format(i_iter, iter_num, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size))
                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            if i_iter == iter_num - 1 or i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '_D1.pth'))

        # decay the learning rates after epoch 7
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                   lr=args.learning_rate, weight_decay=args.weight_decay)
            optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

    # log the end time
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
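# Design note (sketch, not from the original code): the adversarial targets above are built
# with `torch.FloatTensor(out.size()).fill_(label).to(device)`. An equivalent and slightly
# more compact pattern is `torch.full_like`, which also inherits the device and dtype of the
# discriminator output:
def domain_target(d_out, label):
    # an all-`label` tensor shaped like the discriminator output, on the same device
    return torch.full_like(d_out, float(label))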
def main():
    # make dirs
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # cuDNN autotuning ("black-magic" speed-up)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))
    # model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # create the optimizer
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)  # optimizer for the whole model
    optimizer.zero_grad()

    # log the start time and hyper-parameters
    with open(args.file_dir, 'a') as f:
        f.write('start time: ' + str(datetime.now()) + '\n\n')
        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('weight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')
        f.write('epoch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    aveGrad = 0
    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):
            sal_image, sal_label = data_batch['sal_image'], data_batch['sal_label']
            if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING\n')
                continue
            sal_image, sal_label = Variable(sal_image), Variable(sal_label)
            sal_image, sal_label = sal_image.to(device), sal_label.to(device)

            sal_pred = model(sal_image)
            sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data
            sal_loss.backward()

            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            if i_iter % (args.show_every // args.batch_size) == 0:
                print('epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                      .format(i_iter, iter_num, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size))
                with open(args.file_dir, 'a') as f:
                    f.write('epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}\n'
                            .format(i_iter, iter_num, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size))
                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            if i_iter == iter_num - 1 or i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))
                # torch.save(model_D1.state_dict(),
                #            osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '_D1.pth'))

        # decay the learning rate after epoch 7
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                   lr=args.learning_rate, weight_decay=args.weight_decay)

    # log the end time
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
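# Design note (sketch, assumption): the epoch-7 decay above rebuilds the Adam optimizer,
# which also resets its moment estimates. If that reset is not intended, the same 10x decay
# can be applied in place on the existing optimizer:
def decay_learning_rate(optimizer, factor=0.1):
    # scale every parameter group's learning rate by `factor`
    for group in optimizer.param_groups:
        group['lr'] *= factor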
def main(): """Create the model and start the training.""" # device放GPU还是CPU device = torch.device("cuda" if not args.cpu else "cpu") cudnn.enabled = True #一种玄学优化 # Create network # 重要一步 输入类别数 model = build_model() # 生成一个由resnet组成的语义分割模型 # 读取pretrained模型 model.to(device) model.train() model.apply(weights_init) # model.load_state_dict(torch.load(args.restore_from)) model.base.load_pretrained_model(torch.load(args.pretrained_model)) #设置model参数 # 玄学优化 cudnn.benchmark = True # init D 设置D 鉴别器 ''' model_D1 = FCDiscriminator(num_classes=1).to(device) model_D1.train() model_D1.to(device) model_D1.apply(weights_init) ''' # 创建存放模型的文件夹 if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) picloader = get_loader(args) # implement model.optim_parameters(args) to handle different models' lr setting # 优化器 optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate, weight_decay=args.weight_decay) # 整个模型的优化器 optimizer.zero_grad() #optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) # D1的优化器 # optimizer_D1.zero_grad() # 损失函数 bce_loss = torch.nn.BCEWithLogitsLoss() # sigmoid + BCE的完美组合 ''' # 两个改变size的上采样 interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True) # 变为source input的上采样 interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) # 变为 target input的上采样 ''' # save folder run = 0 while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)): run += 1 os.mkdir("%s/run-%d" % (args.snapshot_dir, run)) os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run)) args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run) args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run) # labels for adversarial training 两种域的记号 source_label = 0 target_label = 1 with open(args.file_dir, 'a') as f: f.write('strat time: ' + str(datetime.now()) + '\n\n') f.write('learning rate: ' + str(args.learning_rate) + '\n') f.write('learning rate D: ' + str(args.learning_rate_D) + '\n') f.write('wight decay: ' + str(args.weight_decay) + '\n') f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n') f.write('eptch size: ' + str(args.epotch_size) + '\n') f.write('batch size: ' + str(args.batch_size) + '\n') f.write('iter size: ' + str(args.iter_size) + '\n') f.write('num steps: ' + str(args.num_steps) + '\n\n') for i_epotch in range(args.epotch_size): # 损失值置零 loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 model.zero_grad() # model_D1.zero_grad() loader_iter = enumerate(picloader) for i_iter in range(args.num_steps // args.batch_size // args.iter_size): # 迭代次数 大batch # 优化器梯度置零 + 调整学习率 optimizer.zero_grad() # adjust_learning_rate(optimizer, i_iter) # optimizer_D1.zero_grad() # adjust_learning_rate_D(optimizer_D1, i_iter) for sub_i in range(args.iter_size): # 迭代次数 小batch # get picture _, data_batch = loader_iter.__next__() # 获取一组图片 source_images, source_labels, target_images = data_batch[ 'sal_image'], data_batch['sal_label'], data_batch[ 'edge_image'] #, data_batch['edge_label'] source_images, source_labels, target_images = Variable( source_images), Variable(source_labels), Variable( target_images) if (source_images.size(2) != source_labels.size(2)) or ( source_images.size(3) != source_labels.size(3)): print('IMAGE ERROR, PASSING```') with open(args.file_dir, 'a') as f: f.write('IMAGE ERROR, PASSING```\n') continue # 放入GPU source_images = source_images.to(device) source_labels = source_labels.to(device) 
target_images = target_images.to(device) pred1 = model( source_images) # 三层block和四层block之后classify之后的结果(相当于两种层的结果) # pred_target1 = model(target_images) # 放入模型 # train G # don't accumulate grads in D 不需要D的梯度,因为这里是用D来辅助训练G # for param in model_D1.parameters(): # param.requires_grad = False # train with source # 计算损失函数 loss_seg1 = F.binary_cross_entropy_with_logits(pred1, source_labels, reduction='sum') lossG = loss_seg1 / args.iter_size / args.batch_size loss_seg_value1 += lossG.item() # 记录这次的iter的结果,显示相关和训练不相关 lossG.backward() ''' # D_out1 = model_D1(F.softmax(pred_target1)) # 放入鉴别器(不知道为什么要softmax) D_out1 = model_D1(pred_target1) # 这里用的是bceloss 训练G的时候,target判别为sourse_label时损失函数低 loss_adv_target1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device)) # 后面一个相当于全部是正确答案的和前一个size相同的tensor lossD = loss_adv_target1 / args.iter_size loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size # 记录专用 lossD = lossD * args.lambda_adv_target2 lossD.backward() # train D # bring back requires_grad 恢复D的grad for param in model_D1.parameters(): param.requires_grad = True pred1 = pred1.detach()# train with source 脱离grad # D_out1 = model_D1(F.softmax(pred1))# sourse的判别结果 D_out1 = model_D1(pred1) # 训练D时sourse判断成sourse损失函数低 loss_Ds = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device)) loss_Ds = loss_Ds / args.iter_size loss_D_value1 += loss_Ds.item()# 显示专用 pred_target1 = pred_target1.detach()# train with target target数据训练 脱离 # D_out1 = model_D1(F.softmax(pred_target1))# 得到判别结果 D_out1 = model_D1(pred_target1)# 得到判别结果 # taget判别为target时损失函数低 loss_Dt = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device)) loss_Dt = loss_Dt / args.iter_size loss_D_value1 += loss_Dt.item()# 显示专用 loss_Ds.backward() loss_Dt.backward() ''' # 修改一次参数 optimizer.step() # optimizer_D1.step() ''' # 不管 if args.tensorboard: scalar_info = { 'loss_seg1': loss_seg_value1, 'loss_seg2': loss_seg_value2, 'loss_adv_target1': loss_adv_target_value1, 'loss_adv_target2': loss_adv_target_value2, 'loss_D1': loss_D_value1, 'loss_D2': loss_D_value2, } if i_iter % 10 == 0: for key, val in scalar_info.items(): writer.add_scalar(key, val, i_iter) ''' # 显示 if i_iter * args.batch_size % SHOW_EVERY == 0: print('exp = {}'.format(args.snapshot_dir)) print( 'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}' .format( i_iter, args.num_steps // args.batch_size // args.iter_size, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size)) with open(args.file_dir, 'a') as f: f.write('exp = {}\n'.format(args.snapshot_dir)) f.write( 'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}\n' .format( i_iter, args.num_steps // args.batch_size // args.iter_size, loss_seg_value1, loss_adv_target_value1, loss_D_value1, i_epotch, args.epotch_size)) loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0 # 提前终止 if i_iter >= args.num_steps_stop - 1: print('save model ...') with open(args.file_dir, 'a') as f: f.write('save model ...\n') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(args.num_steps_stop) + '.pth')) # torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(args.num_steps_stop) + '_D1.pth')) break if i_iter == args.num_steps // args.batch_size // args.iter_size - 1 or i_iter * args.batch_size * args.iter_size % args.save_pred_every == 
0 and i_iter != 0: print('taking snapshot ...') with open(args.file_dir, 'a') as f: f.write('taking snapshot ...\n') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth')) # torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'sal_' + str(i_epotch) + '_' + str(i_iter) + '_D1.pth')) ''' if args.tensorboard: writer.close() ''' with open(args.file_dir, 'a') as f: f.write('end time: ' + str(datetime.now()) + '\n')
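# Hedged sketch: the argument parser is not included in these snippets. The stub below only
# lists flags corresponding to `args.*` attributes that actually appear above; every default
# value is an illustrative assumption, not the project's configuration.
import argparse

def get_arguments():
    parser = argparse.ArgumentParser(description='PoolNet saliency/edge adversarial training')
    parser.add_argument('--snapshot-dir', type=str, default='./snapshots/')
    parser.add_argument('--restore-from', type=str, default='')
    parser.add_argument('--pretrained-model', type=str, default='')
    parser.add_argument('--learning-rate', type=float, default=5e-5)
    parser.add_argument('--learning-rate-D', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--lambda-adv-target2', type=float, default=0.001)
    parser.add_argument('--epotch-size', type=int, default=8)  # spelling kept to match the code above
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--iter-size', type=int, default=10)
    parser.add_argument('--num-steps', type=int, default=250000)
    parser.add_argument('--num-steps-stop', type=int, default=150000)
    parser.add_argument('--save-pred-every', type=int, default=5000)
    parser.add_argument('--cpu', action='store_true')
    return parser.parse_args()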