# Per-epoch metric table: one row per epoch, 24 slots per row.
avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)
best_loss = 100

for epoch in range(total_epoch):
    index = epoch
    # Metric accumulator for the current epoch (same 24 slots as avg_cost).
    cost = np.zeros(24, dtype=np.float32)

    # One running-average meter per task, keyed by the task's position.
    dist_loss_save = {}
    for i, t in enumerate(tasks):
        dist_loss_save[i] = AverageMeter()

    # apply Dynamic Weight Average
    bar = Bar('Training', max=train_batch)

    # iteration for all batches
    model.train()
    nyuv2_train_dataset = iter(nyuv2_train_loader)
    for k in range(train_batch):
        # Use the builtin next(): Python 3 iterators implement __next__ and
        # recent DataLoader iterators no longer provide a .next() method.
        train_data, train_label, train_depth, train_normal = next(
            nyuv2_train_dataset)
        train_data = train_data.cuda()
        train_label = train_label.type(torch.LongTensor).cuda()
        train_depth, train_normal = train_depth.cuda(), train_normal.cuda()

        # The model returns the predictions plus two extra outputs
        # (logsigma, feat_s); train_pred[0..2] are paired below with the
        # label, depth and normal targets respectively.
        train_pred, logsigma, feat_s = model(train_data)
        train_loss = model.model_fit(train_pred[0], train_label,
                                     train_pred[1], train_depth,
                                     train_pred[2], train_normal)
        # NOTE(review): presumably the accumulator for the combined
        # per-task loss built further down - confirm against the rest
        # of the loop body.
        loss = 0
def main():
    """Train the SegNet model with SGD and periodically save results.

    Creates the snapshot and debug-visualisation directories if needed,
    builds the model and an SGD optimizer with four parameter groups,
    then iterates over the training DataLoader. Every 10 steps the loss
    is printed; every 10000 steps a debug visualisation and a model
    checkpoint are written.
    """
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(args.snapshot_dir, exist_ok=True)
    os.makedirs(args.train_debug_vis_dir, exist_ok=True)

    model = SegNet(model='resnet50')
    # NOTE(review): the original comment here read "freeze bn statics",
    # but train() only switches the module to training mode - confirm
    # whether SegNet freezes its batch-norm layers internally.
    model.train()
    model.cuda()

    # Four parameter groups: biases get twice the learning rate of the
    # weights, and the "added" layers get 10x the rate of the "backbone".
    param_groups = [
        {
            "params": get_params(model, key="backbone", bias=False),
            "lr": INI_LEARNING_RATE,
        },
        {
            "params": get_params(model, key="backbone", bias=True),
            "lr": 2 * INI_LEARNING_RATE,
        },
        {
            "params": get_params(model, key="added", bias=False),
            "lr": 10 * INI_LEARNING_RATE,
        },
        {
            "params": get_params(model, key="added", bias=True),
            "lr": 20 * INI_LEARNING_RATE,
        },
    ]
    optimizer = torch.optim.SGD(params=param_groups,
                                lr=INI_LEARNING_RATE,
                                weight_decay=WEIGHT_DECAY)

    dataloader = DataLoader(SegDataset(mode='train'),
                            batch_size=8,
                            shuffle=True,
                            num_workers=4)

    global_step = 0
    # NOTE(review): range(1, EPOCHES) performs EPOCHES - 1 passes over the
    # data; kept as in the original - confirm this is intended.
    for epoch in range(1, EPOCHES):
        for i_iter, batch_data in enumerate(dataloader):
            global_step += 1
            (Input_image, vis_image, gt_mask, weight_matrix,
             dataset_length, image_name) = batch_data

            optimizer.zero_grad()
            pred_mask = model(Input_image.cuda())
            loss = loss_calc(pred_mask, gt_mask, weight_matrix)
            loss.backward()
            optimizer.step()

            if global_step % 10 == 0:
                print('epoche {} i_iter/total {}/{} loss {:.4f}'.format(
                    epoch, i_iter, int(dataset_length[0].data), loss.item()))

            # Visualisation and checkpoint share one integer interval (the
            # original tested it twice, once with a float modulus).
            if global_step % 10000 == 0:
                # os.path.join keeps the files inside the directories
                # created above even when the configured path has no
                # trailing separator; with one, the path is unchanged.
                vis_pred_result(
                    vis_image, gt_mask, pred_mask,
                    os.path.join(args.train_debug_vis_dir,
                                 str(global_step) + '.png'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.snapshot_dir,
                                 str(global_step) + '.pth'))