'snapshot': '', 'pretrain': os.path.join(ckpt_path, 'VideoSaliency_2019-12-24 22:05:11', '50000.pth'), # 'pretrain': '', 'imgs_file': 'Pre-train/pretrain_all_seq_DUT_TR_DAFB2_DAVSOD2.txt', # 'imgs_file': 'video_saliency/train_all_DAFB2_DAVSOD_5f.txt', 'train_loader': 'both' # 'train_loader': 'video_sequence' } imgs_file = os.path.join(datasets_root, args['imgs_file']) # imgs_file = os.path.join(datasets_root, 'video_saliency/train_all_DAFB3_seq_5f.txt') joint_transform = joint_transforms.Compose([ joint_transforms.ImageResize(520), joint_transforms.RandomCrop(473), joint_transforms.RandomHorizontallyFlip(), joint_transforms.RandomRotate(10) ]) # joint_seq_transform = joint_transforms.Compose([ # joint_transforms.ImageResize(520), # joint_transforms.RandomCrop(473) # ]) input_size = (473, 473) img_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) target_transform = transforms.ToTensor()
def main(train_args):
    """Train a DeepLabV3 segmentation model on the `yaogan` dataset.

    Builds the model (optionally wrapped in DataParallel on all visible
    GPUs), optionally resumes from a snapshot, sets up input/target
    transforms and data loaders, then runs the train/validate loop.

    Args:
        train_args: configuration dict. Keys read here: 'ckpt_path',
            'exp_name', 'snapshot', 'training_cls', 'batch_size',
            'num_works', 'lr', 'weight', 'epoch_num'. 'best_record' is
            written into this dict as a side effect.

    Side effects: creates checkpoint directories, writes a timestamped
    config dump file, mutates ``train_args['best_record']``, and closes
    the module-level ``writer`` when training finishes.
    """
    check_mkdir(os.path.join(train_args['ckpt_path'], args['exp']))
    check_mkdir(
        os.path.join(train_args['ckpt_path'], args['exp'],
                     train_args['exp_name']))

    model = DeepLabV3('1')
    # print(model)
    device = torch.device("cuda")
    num_gpu = list(range(torch.cuda.device_count()))

    # ---------------- move model to GPU(s) ----------------
    if args['use_gpu']:
        ts = time.time()
        print(torch.cuda.current_device())
        print(torch.cuda.get_device_name(0))
        model = nn.DataParallel(model, device_ids=num_gpu)
        model = model.to(device)
        # BUG FIX: original was print("...{}", format(...)) — it called the
        # *builtin* format() and printed the template and the number as two
        # separate print arguments, leaving "{}" unfilled. Use str.format.
        print("Finish cuda loading ,time elapsed {}".format(time.time() - ts))
    else:
        print("please check your gpu device,start training on cpu")

    # ---------------- resume training from a snapshot ----------------
    if len(train_args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with a sentinel best record.
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
        # model.apply(weights_init)
    else:
        print("train resume from " + train_args['snapshot'])
        state_dict = torch.load(
            os.path.join(train_args['ckpt_path'], args['exp'],
                         train_args['exp_name'], train_args['snapshot']))
        # The checkpoint was saved from an nn.DataParallel model, so every
        # key is prefixed with 'module.' (7 chars); strip it before loading.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        # model.load_state_dict(
        #     torch.load(os.path.join(train_args['ckpt_path'],args['exp'],train_args['exp_name'], train_args['snapshot'])))

        # Snapshot filenames encode metrics as underscore-separated fields:
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    model.train()

    # 4-channel input statistics (RGB + one extra band).
    mean_std = ([0.485, 0.456, 0.406, 0.450], [0.229, 0.224, 0.225, 0.225])

    # ---------------- data augmentation / transforms ----------------
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])  # normalize inputs
    target_transform = extended_transforms.MaskToTensor()  # mask -> tensor
    # NOTE(review): this joint_transform is built but never used — the
    # dataset below is created with joint_transform=None. Confirm whether
    # augmentation was meant to be enabled for the training set.
    joint_transform = joint_transforms.Compose([
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop((256, 256), padding=0),
        joint_transforms.Rotate(degree=90)
    ])  # data augmentation
    restore = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        # 4-channel tensors are reduced to 3 channels for visualization
        extended_transforms.channel_4_to_channel_3(4, 3),
        standard_transforms.ToPILImage(),
    ])  # de-normalize and convert back to a PIL image
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(256),
        standard_transforms.CenterCrop(256),  # center crop; optional
        standard_transforms.ToTensor()
    ])  # resize then convert to tensor for visualization

    # ---------------- data loading ----------------
    train_set = yaogan(mode='train', cls=train_args['training_cls'],
                       joint_transform=None,
                       input_transform=input_transform,
                       target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=train_args['num_works'],
                              shuffle=True)
    val_set = yaogan(mode='val', cls=train_args['training_cls'],
                     input_transform=input_transform,
                     target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1,
                            num_workers=train_args['num_works'],
                            shuffle=False)
    # test_set=yaogan(mode='test',cls=train_args['training_cls'],joint_transform=None,
    #                 input_transform=input_transform,target_transform=None)
    # test_loader=DataLoader(test_set,batch_size=1,
    #                        num_workers=train_args['num_works'], shuffle=False)

    optimizer = optim.Adadelta(model.parameters(), lr=train_args['lr'])

    # define a weighted loss (0 weight for 0 label)
    # weight=[0.09287939 ,0.02091968 ,0.02453979, 0.25752962 ,0.33731845, 1.,
    #         0.09518322, 0.52794035 ,0.24298112 ,0.02657369, 0.15057124 ,0.36864611,
    #         0.25835161,0.16672758 ,0.40728756 ,0.00751281]

    # ---------------- class weights for the loss ----------------
    if train_args['weight'] is not None:
        weight = [0.1, 1.]
        weight = torch.Tensor(weight)
    else:
        weight = None
    # FIX: reduction='elementwise_mean' is the deprecated pre-1.0 spelling
    # and was removed in modern PyTorch; 'mean' is the equivalent value.
    criterion = nn.CrossEntropyLoss(weight=weight,
                                    reduction='mean',
                                    ignore_index=-100).to(device)
    # criterion=nn.BCELoss(weight=weight,reduction='elementwise_mean').cuda()

    check_mkdir(train_args['ckpt_path'])
    check_mkdir(os.path.join(train_args['ckpt_path'], args['exp']))
    check_mkdir(
        os.path.join(train_args['ckpt_path'], args['exp'],
                     train_args['exp_name']))
    # FIX: use a context manager so the config-dump file handle is closed
    # (the original open(...).write(...) leaked it).
    with open(
            os.path.join(train_args['ckpt_path'], args['exp'],
                         train_args['exp_name'],
                         str(time.time()) + '.txt'), 'w') as f:
        f.write(str(train_args) + '\n\n')

    # ---------------- training loop ----------------
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        adjust_lr(optimizer, epoch)
        train(train_loader, model, criterion, optimizer, epoch, train_args,
              device)
        # validate() handles best-record bookkeeping itself; its return
        # value was previously bound to an unused local.
        validate(val_loader, model, criterion, optimizer, restore, epoch,
                 train_args, visualize, device)
    writer.close()