def main(args):
    """Extract max/avg-pooled features for the CLEVR 'val' images.

    Loads the RN checkpoint at ``args.checkpoint`` and lets
    ``extract_features_rl`` write the features to
    ``./features/max_features.pickle`` and ``./features/avg_features.pickle``.

    Args:
        args: parsed command-line namespace. Reads ``checkpoint``, ``no_cuda``,
            ``clevr_dir`` and ``batch_size``; sets ``cuda``, ``features_dirs``,
            ``qdict_size`` and ``adict_size`` on it.
    """

    def _match_module_prefix(state_dict, wrapped):
        # DataParallel prefixes every parameter name with 'module.' (pytorch
        # issue #3805). Add or strip that leading prefix so the checkpoint keys
        # match the model; only the leading occurrence is touched.
        prefix = 'module.'
        has_prefix = any(k.startswith(prefix) for k in state_dict)
        if wrapped and not has_prefix:
            return {prefix + k: v for k, v in state_dict.items()}
        if has_prefix and not wrapped:
            return {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()}
        return state_dict

    assert os.path.isfile(args.checkpoint), "Checkpoint file not found: {}".format(args.checkpoint)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    test_transforms = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    # Initialize CLEVR Loader
    clevr_dataset_images = ClevrDatasetImages(args.clevr_dir, 'val', test_transforms)
    # NOTE(review): drop_last=True discards the final incomplete batch, so the
    # last len(dataset) % batch_size images get no features — confirm intended.
    clevr_feat_extraction_loader = DataLoader(clevr_dataset_images, batch_size=args.batch_size,
                                              shuffle=False, num_workers=8, drop_last=True)

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)
    max_features_path = os.path.join(args.features_dirs, 'max_features.pickle')
    avg_features_path = os.path.join(args.features_dirs, 'avg_features.pickle')

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method
    if args.cuda:
        model.cuda()

    # Load the model checkpoint; tensors are mapped to CPU storage first so a
    # GPU-saved checkpoint also loads on a CPU-only machine.
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    checkpoint = _match_module_prefix(checkpoint, isinstance(model, torch.nn.DataParallel))
    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    # The context manager guarantees the pickles are flushed and closed.
    with open(max_features_path, 'wb') as max_features, \
            open(avg_features_path, 'wb') as avg_features:
        extract_features_rl(clevr_feat_extraction_loader, max_features, avg_features, model, args)
def main(args):
    """Train (or, with ``args.test``, evaluate once) an RN model on CLEVR.

    Builds the word dictionaries and the CLEVR datasets (images or state
    descriptions, depending on ``args.state_description``). It then creates
    the ``RN`` model, wrapped in ``DataParallel`` when several GPUs are used.
    It can optionally resume from ``args.resume`` and/or transfer conv-layer
    weights from ``args.conv_transfer_learn``. Training then runs for up to
    ``args.epochs`` epochs, and ``RN_epoch_XX.pth`` is saved into
    ``args.model_dirs`` after each epoch.

    Args:
        args: parsed command-line namespace; this function also sets
            ``model_dirs``, ``features_dirs``, ``test_results_dir``, ``cuda``,
            ``qdict_size`` and ``adict_size`` on it.

    Raises:
        ValueError: if ``args.resume`` names an existing file whose name does
            not contain ``epoch_<N>.pth`` (the resume epoch cannot be inferred).
    """

    def _match_module_prefix(state_dict, wrapped):
        # DataParallel prefixes every parameter name with 'module.' (pytorch
        # issue #3805). Add or strip that leading prefix so the keys match the
        # model being loaded, whatever GPU count the checkpoint was saved with.
        prefix = 'module.'
        has_prefix = any(k.startswith(prefix) for k in state_dict)
        if wrapped and not has_prefix:
            return {prefix + k: v for k, v in state_dict.items()}
        if has_prefix and not wrapped:
            return {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()}
        return state_dict

    args.model_dirs = './model_{}_b{}_lr{}'.format(args.model, args.batch_size, args.lr)
    args.features_dirs = './features'
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)
    args.test_results_dir = './test_results'
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')
    if not args.state_description:
        train_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Pad(8),
            transforms.RandomCrop((128, 128)),
            transforms.RandomRotation(2.8),  # .05 rad
            transforms.ToTensor()
        ])
        test_transforms = transforms.Compose(
            [transforms.Resize((128, 128)), transforms.ToTensor()])
        clevr_dataset_train = ClevrDataset(args.clevr_dir, True, dictionaries, train_transforms)
        clevr_dataset_test = ClevrDataset(args.clevr_dir, False, dictionaries, test_transforms)

        # Use a weighted sampler for training:
        weights = clevr_dataset_train.answer_weights()
        sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(clevr_dataset_train, batch_size=args.batch_size,
                                        sampler=sampler, num_workers=8,
                                        collate_fn=utils.collate_samples_image)
        clevr_test_loader = DataLoader(clevr_dataset_test, batch_size=args.batch_size,
                                       shuffle=False, num_workers=8,
                                       collate_fn=utils.collate_samples_image)
    else:
        clevr_dataset_train = ClevrDatasetStateDescription(args.clevr_dir, True, dictionaries)
        clevr_dataset_test = ClevrDatasetStateDescription(args.clevr_dir, False, dictionaries)

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(clevr_dataset_train, batch_size=args.batch_size,
                                        shuffle=True, num_workers=8,
                                        collate_fn=utils.collate_samples_state_description)
        clevr_test_loader = DataLoader(clevr_dataset_test, batch_size=args.batch_size,
                                       shuffle=False, num_workers=8,
                                       collate_fn=utils.collate_samples_state_description)
    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method
    if args.cuda:
        model.cuda()
    # Decide wrapping from the model itself, not from the GPU count: with zero
    # GPUs or --no_cuda the model is not wrapped regardless of device_count().
    is_parallel = isinstance(model, torch.nn.DataParallel)
    core_model = model.module if is_parallel else model

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            epoch_match = re.match(r'.*epoch_(\d+)\.pth', filename)
            if epoch_match is None:
                raise ValueError(
                    'Cannot infer the epoch from resume file {}; expected a name '
                    'like RN_epoch_XX.pth'.format(filename))
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
            model.load_state_dict(_match_module_prefix(checkpoint, is_parallel))
            print('==> loaded checkpoint {}'.format(filename))
            start_epoch = int(epoch_match.group(1)) + 1
        else:
            print('Cannot load file {}'.format(filename))

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            print('==> loading conv layer from {}'.format(args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained
            # conv layer; drop any DataParallel prefix so conv keys start with 'conv.'
            pretrained_dict = _match_module_prefix(
                torch.load(args.conv_transfer_learn, map_location=lambda storage, loc: storage),
                False)
            conv_dict = core_model.conv.state_dict()
            # keep only the conv-layer entries, renamed relative to the conv module,
            # and overwrite them in the existing state dict
            conv_dict.update({
                k[len('conv.'):]: v
                for k, v in pretrained_dict.items() if k.startswith('conv.')
            })
            core_model.conv.load_state_dict(conv_dict)
            # The transferred weights stay trainable; set requires_grad = False on
            # core_model.conv.parameters() to freeze them instead.
            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    progress_bar = trange(start_epoch, args.epochs + 1)
    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
        return

    # perform a full training
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.lr, weight_decay=1e-4)
    print('Training ({} epochs) is starting...'.format(args.epochs))
    for epoch in progress_bar:
        # TRAIN
        progress_bar.set_description('TRAIN')
        train(clevr_train_loader, model, optimizer, epoch, args)
        # TEST
        progress_bar.set_description('TEST')
        test(clevr_test_loader, model, epoch, dictionaries, args)
        # SAVE MODEL
        filename = 'RN_epoch_{:02d}.pth'.format(epoch)
        torch.save(model.state_dict(), os.path.join(args.model_dirs, filename))
def main(args):
    """Extract RN features for one CLEVR split and pickle them under ./features.

    Hyper-parameters come from ``hyperparams[args.model]`` in the JSON file
    ``args.config``; ``args.question_injection`` (>= 0) overrides the
    configured injection position. The model is restored from
    ``args.checkpoint``, and ``extract_features_rl`` writes into the files in
    ``files_dict``. With ``args.extr_layer_idx >= 0`` these are
    ``<set>_2S-RN_{max,avg}_features.pickle``; otherwise they are
    ``<set>_RN_{avg,max}_features.pickle``.

    Args:
        args: parsed command-line namespace; this function also sets ``cuda``,
            ``features_dirs``, ``qdict_size`` and ``adict_size`` on it.
    """

    def _match_module_prefix(state_dict, wrapped):
        # DataParallel prefixes every parameter name with 'module.' (pytorch
        # issue #3805). Add or strip that leading prefix so the keys match the
        # model as actually built (wrapped or not).
        prefix = 'module.'
        has_prefix = any(k.startswith(prefix) for k in state_dict)
        if wrapped and not has_prefix:
            print('Adding \'module.\' prefix')
            return {prefix + k: v for k, v in state_dict.items()}
        if has_prefix and not wrapped:
            print('Removing \'module.\' prefix')
            return {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()}
        return state_dict

    # load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override the configured question-injection position
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection
    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    assert os.path.isfile(args.checkpoint), "Checkpoint file not found: {}".format(
        args.checkpoint)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Initialize CLEVR Loader
    clevr_dataset_test = initialize_dataset(args.clevr_dir, args.set == 'train',
                                            hyp['state_description'])
    clevr_feat_extraction_loader = reload_loaders(clevr_dataset_test, args.batch_size,
                                                  hyp['state_description'])

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])

    print('Cuda: {}'.format(args.cuda))
    model = RN(args, hyp, extraction=True)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method
    if args.cuda:
        model.cuda()

    # Load the model checkpoint (mapped to CPU storage so it loads without a GPU)
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    checkpoint = _match_module_prefix(checkpoint, isinstance(model, torch.nn.DataParallel))
    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    # Output files are opened only once the model has loaded, so a failed load
    # does not truncate existing feature files, and they are always closed.
    if args.extr_layer_idx >= 0:  # g_layers features
        # NOTE(review): extr_layer_idx is not part of these file names, so
        # extractions for different layers overwrite each other — confirm.
        file_names = [
            ('max_features', '{}_2S-RN_max_features.pickle'.format(args.set)),
            ('avg_features', '{}_2S-RN_avg_features.pickle'.format(args.set)),
        ]
    else:
        file_names = [
            ('avgconv_features', '{}_RN_avg_features.pickle'.format(args.set)),
            ('maxconv_features', '{}_RN_max_features.pickle'.format(args.set)),
        ]

    files_dict = {}
    try:
        for key, name in file_names:
            files_dict[key] = open(os.path.join(args.features_dirs, name), 'wb')
        extract_features_rl(clevr_feat_extraction_loader, hyp['question_injection_position'],
                            args.extr_layer_idx, hyp['lstm_hidden'], files_dict, model, args)
    finally:
        for feature_file in files_dict.values():
            feature_file.close()
best_acc = 0 # LOSS # ----------------------------- if loss_type== 'bce': criterion = nn.BCELoss() else: criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) print("model and criterion loaded ...") checkpoint = torch.load('./saved_models/model_arr_20190223_21_0.0001_64_mse.ckpt') model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint['epoch'] model.eval() with torch.no_grad(): count = 0 total = 0 prec_count = 0 blankz_count = 0 for i in range(num_val_iters-1): # x, y = read_csv_batch('../../data/final_plans/final_val.csv', i, val_batch_size) x,y = read_batch(rows_val, i, val_batch_size) x = torch.tensor(x) x = nn.functional.interpolate(x, (75, 75))
class Task():
    """Train and evaluate an RN model on the generated instruction dataset.

    Owns dataset/loader construction, the model (wrapped in ``DataParallel``
    on multi-GPU hosts), the Adam optimizer, the cross-entropy loss,
    TensorBoard logging and checkpointing under ``<logPath>/models``.
    Hyper-parameters not taken from ``args`` are hard-coded in ``__init__``.
    """

    def __init__(self, args):
        print('#' * 60)
        print(' ' * 20 + ' Task Created ' + ' ' * 20)
        print('#' * 60)

        # Parameters
        self.batchSize = args.batchSize
        self.lr = args.lr
        self.weightDecay = 1e-4
        self.objNumMax = 30
        self.wordEmbeddingDim = 128
        self.instructionLength = 10
        self.pinMemory = True
        self.dropout = False
        self.epoch = args.epoch
        self.epoch_i = 0
        self.batchPrint = 100
        self.batchModelSave = args.batchModelSave
        self.checkPoint = args.checkPoint

        # Path
        self.vocabularyPath = './data/vocabulary.json'
        self.trainDatasetPath = './data/generated_data_train.json'
        self.testDatasetPath = './data/generated_data_test.json'
        self.logPath = args.logPath

        # Dataset
        self.trainDataset = DatasetGenerator(datasetPath=self.trainDatasetPath,
                                             vocabularyPath=self.vocabularyPath)
        self.testDataset = DatasetGenerator(datasetPath=self.testDatasetPath,
                                            vocabularyPath=self.vocabularyPath)

        # Tokenizer
        self.tokenizer = Tokenizer(vocabPath=self.vocabularyPath)
        self.num_embedding = self.tokenizer.get_num_embedding()

        # DataLoader
        self.trainDataLoader = DataLoader(dataset=self.trainDataset, shuffle=True,
                                          batch_size=self.batchSize, num_workers=12,
                                          pin_memory=self.pinMemory)
        self.testDataLoader = DataLoader(dataset=self.testDataset, shuffle=False,
                                         batch_size=self.batchSize, num_workers=12,
                                         pin_memory=self.pinMemory)

        # calculate batch numbers
        self.trainBatchNum = int(np.ceil(len(self.trainDataset) / float(self.batchSize)))
        self.testBatchNum = int(np.ceil(len(self.testDataset) / float(self.batchSize)))

        # Create model
        self.RN = RN(num_embedding=self.num_embedding,
                     embedding_dim=self.wordEmbeddingDim,
                     obj_num_max=self.objNumMax)

        # Run task on all available GPUs
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                print("Use ", torch.cuda.device_count(), " GPUs")
                self.RN = nn.DataParallel(self.RN)
            self.RN = self.RN.cuda()
            print('Model Created on GPUs.')

        # Optimizer
        self.optimizer = optim.Adam(self.RN.parameters(), lr=self.lr, betas=(0.9, 0.999),
                                    eps=1e-08, weight_decay=self.weightDecay)

        # Scheduler. NOTE: nothing in this class calls self.scheduler.step(),
        # so the learning rate currently stays constant.
        self.scheduler = ReduceLROnPlateau(self.optimizer, factor=0.1, patience=10, mode='min')

        # Loss Function
        self.loss = torch.nn.CrossEntropyLoss()

        # Load model if a checkPoint is given
        if self.checkPoint != "":
            self.load(self.checkPoint)

        # TensorboardX record
        self.writer = SummaryWriter()
        self.stepCnt_train = 1
        self.stepCnt_test = 1

    @staticmethod
    def _toDevice(objs_coordinate, objs_category_idx, instruction_idx, target):
        """Cast the *_idx tensors to int64 and move the batch to the GPU if present."""
        objs_category_idx = objs_category_idx.long()
        instruction_idx = instruction_idx.long()
        if torch.cuda.is_available():
            objs_coordinate = objs_coordinate.cuda()
            objs_category_idx = objs_category_idx.cuda()
            instruction_idx = instruction_idx.cuda()
            target = target.cuda()
        return objs_coordinate, objs_category_idx, instruction_idx, target

    @staticmethod
    def _formatTargetOutput(target, output):
        """Render one 'Target: ... Output: ...' entry per sample for TensorBoard text.

        ``target`` is indexed per sample and ``output`` row by row; both are
        expected to be numpy arrays (as produced by ``.cpu().numpy()``).
        """
        return ''.join(
            'Target: %s Output: %s ######### ' % (
                str(sample_target), ', '.join(str(v) for v in sample_output.tolist()))
            for sample_target, sample_output in zip(target, output))

    def train(self):
        """Run ``self.epoch`` epochs, each followed by a pass over the test set."""
        print('Training task begin.')
        print('----Batch Size: %d' % self.batchSize)
        print('----Learning Rate: %f' % (self.lr))
        print('----Epoch: %d' % self.epoch)
        print('----Log Path: %s' % self.logPath)
        for self.epoch_i in range(self.epoch):
            self.epochTrain()
            self.test()

    def epochTrain(self):
        """Train for one epoch, logging loss/outputs and saving periodic checkpoints."""
        s = '#' * 30 + ' Epoch %3d / %3d ' % (self.epoch_i + 1, self.epoch) + '#' * 30
        print(s)
        # test() switches the model to eval mode; switch back before training.
        self.RN.train()
        bar = tqdm(self.trainDataLoader)
        for idx, (objs_coordinate, objs_category, objs_category_idx, instruction,
                  instruction_idx, target, data) in enumerate(bar):
            objs_coordinate, objs_category_idx, instruction_idx, target = self._toDevice(
                objs_coordinate, objs_category_idx, instruction_idx, target)

            # Go through the model
            output = self.RN(objs_coordinate, objs_category_idx, instruction_idx)

            # calculate loss
            lossValue = self.loss(input=output, target=target)

            # Tensorboard record
            self.writer.add_scalar('Loss/Train', lossValue.item(), self.stepCnt_train)

            # print loss
            bar.set_description('Epoch: %d Loss: %f' % (self.epoch_i + 1, lossValue.item()))

            # Backward
            self.optimizer.zero_grad()
            lossValue.backward()
            self.optimizer.step()

            # Save model
            if (idx + 1) % self.batchModelSave == 0:
                self.save(batchIdx=(idx + 1))

            if idx % self.batchPrint == 0:
                output = output.detach().cpu().numpy()
                target = target.detach().cpu().numpy()
                self.writer.add_text('Target & Output',
                                     self._formatTargetOutput(target, output),
                                     self.stepCnt_train)
                self.writer.add_histogram('output', output, self.stepCnt_train)
                self.writer.add_histogram('target', target, self.stepCnt_train)
                for name, param in self.RN.named_parameters():
                    self.writer.add_histogram(name, param.clone().cpu().data.numpy(),
                                              self.stepCnt_train)

            self.stepCnt_train += 1
            del lossValue

    def test(self):
        """Evaluate on the test set (eval mode, no gradients) and log loss/outputs."""
        s = '#' * 28 + ' Test Epoch %3d / %3d ' % (self.epoch_i + 1, self.epoch) + '#' * 28
        print(s)
        # Disable dropout/batch-norm updates and autograd bookkeeping.
        self.RN.eval()
        bar = tqdm(self.testDataLoader)
        with torch.no_grad():
            for idx, (objs_coordinate, objs_category, objs_category_idx, instruction,
                      instruction_idx, target, data) in enumerate(bar):
                objs_coordinate, objs_category_idx, instruction_idx, target = self._toDevice(
                    objs_coordinate, objs_category_idx, instruction_idx, target)

                # Go through the model
                output = self.RN(objs_coordinate, objs_category_idx, instruction_idx)

                # calculate loss
                lossValue = self.loss(input=output, target=target)

                # Tensorboard record
                self.writer.add_scalar('Loss/Test', lossValue.item(), self.stepCnt_test)

                # print loss
                bar.set_description('Epoch: %d Loss: %f' % (self.epoch_i + 1, lossValue.item()))

                if idx % self.batchPrint == 0:
                    output = output.detach().cpu().numpy()
                    target = target.detach().cpu().numpy()
                    self.writer.add_text('Target & Output Test',
                                         self._formatTargetOutput(target, output),
                                         self.stepCnt_test)
                    self.writer.add_histogram('output', output, self.stepCnt_test)
                    self.writer.add_histogram('target', target, self.stepCnt_test)

                self.stepCnt_test += 1
                del lossValue

    def save(self, batchIdx=None):
        """Save the model state (plus run metadata) under ``<logPath>/models``.

        Args:
            batchIdx: 1-based batch index within the current epoch, or ``None``
                for an end-of-epoch checkpoint.
        """
        dirPath = os.path.join(self.logPath, 'models')
        if not os.path.exists(dirPath):
            # makedirs also creates a missing logPath.
            os.makedirs(dirPath)
        if batchIdx is None:
            path = os.path.join(dirPath, 'Epoch-%03d-end.pth.tar' % (self.epoch_i + 1))
        else:
            path = os.path.join(dirPath,
                                'Epoch-%03d-Batch-%04d.pth.tar' % (self.epoch_i + 1, batchIdx))
        torch.save(
            {
                'epochs': self.epoch_i + 1,
                'batch_size': self.batchSize,
                'lr': self.lr,
                'weight_dacay': self.weightDecay,  # key spelling kept for compatibility
                'RN_model_state_dict': self.RN.state_dict()
            }, path)
        print('Training log saved to %s' % path)

    def load(self, path):
        """Load model weights from a checkpoint written by ``save``.

        Tensors are mapped to CPU storage first, so GPU checkpoints load on any
        host. The 'module.' key prefix added by DataParallel is added or
        stripped so that checkpoints move between single- and multi-GPU setups.
        """
        modelCheckpoint = torch.load(path, map_location=lambda storage, loc: storage)
        stateDict = modelCheckpoint['RN_model_state_dict']
        prefix = 'module.'
        wrapped = isinstance(self.RN, nn.DataParallel)
        hasPrefix = any(k.startswith(prefix) for k in stateDict)
        if wrapped and not hasPrefix:
            stateDict = {prefix + k: v for k, v in stateDict.items()}
        elif hasPrefix and not wrapped:
            stateDict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                         for k, v in stateDict.items()}
        self.RN.load_state_dict(stateDict)
        print('Load model from %s' % path)
for img, relations, norelations in test_datasets: img = np.swapaxes(img, 0, 2) for qst, ans in zip(relations[0], relations[1]): rel_test.append((img, qst, ans)) for qst, ans in zip(norelations[0], norelations[1]): norel_test.append((img, qst, ans)) return (rel_train, rel_test, norel_train, norel_test) rel_train, rel_test, norel_train, norel_test = load_data() try: os.makedirs(model_dirs) except: print('directory {} already exists'.format(model_dirs)) if args.resume: filename = os.path.join(model_dirs, args.resume) if os.path.isfile(filename): print('==> loading checkpoint {}'.format(filename)) checkpoint = torch.load(filename) model.load_state_dict(checkpoint) print('==> loaded checkpoint {}'.format(filename)) for epoch in range(1, args.epochs + 1): train(epoch, rel_train, norel_train) test(epoch, rel_test, norel_test) model.save_model(epoch)
def main(args):
    """Train (or, with ``args.test``, evaluate once) an RN model on CLEVR.

    Hyper-parameters come from ``hyperparams[args.model]`` in the JSON file
    ``args.config``. ``args.dropout`` (> 0) and ``args.question_injection``
    (>= 0) override the corresponding entries. The full configuration is
    written to ``<model_dirs>/config.txt``, and ``RN_epoch_XX.pth`` is saved
    into ``args.model_dirs`` after every epoch.

    Batch size follows ``batch_size * bs_gamma ** (epoch // bs_step)``; it is
    capped at ``bs_max`` when ``bs_max > 0`` and grows unbounded when
    ``bs_max < 0``. The learning rate resumes at
    ``lr * lr_gamma ** ((start_epoch - 1) // lr_step)``, capped at ``lr_max``
    when ``lr_max > 0``. A ``StepLR`` scheduler then advances it each epoch
    while the current rate is below ``lr_max``, or always when ``lr_max < 0``.

    Args:
        args: parsed command-line namespace; this function also sets
            ``model_dirs``, ``features_dirs``, ``test_results_dir``, ``cuda``,
            ``qdict_size`` and ``adict_size`` on it.

    Raises:
        ValueError: if ``args.resume`` names an existing file whose name does
            not contain ``epoch_<N>.pth`` (the resume epoch cannot be inferred).
    """

    def _match_module_prefix(state_dict, wrapped):
        # DataParallel prefixes every parameter name with 'module.' (pytorch
        # issue #3805). Add or strip that leading prefix so the keys match the
        # model being loaded, whatever GPU count the checkpoint was saved with.
        prefix = 'module.'
        has_prefix = any(k.startswith(prefix) for k in state_dict)
        if wrapped and not has_prefix:
            return {prefix + k: v for k, v in state_dict.items()}
        if has_prefix and not wrapped:
            return {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()}
        return state_dict

    # load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override configuration dropout / question injection position
    if args.dropout > 0:
        hyp['dropout'] = args.dropout
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection
    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    args.model_dirs = '{}/model_{}_drop{}_bstart{}_bstep{}_bgamma{}_bmax{}_lrstart{}_' + \
        'lrstep{}_lrgamma{}_lrmax{}_invquests-{}_clipnorm{}_glayers{}_qinj{}_fc1{}_fc2{}'
    args.model_dirs = args.model_dirs.format(
        args.exp_dir, args.model, hyp['dropout'], args.batch_size, args.bs_step,
        args.bs_gamma, args.bs_max, args.lr, args.lr_step, args.lr_gamma, args.lr_max,
        args.invert_questions, args.clip_norm, hyp['g_layers'],
        hyp['question_injection_position'], hyp['f_fc1'], hyp['f_fc2'])
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)
    # create a file in this folder containing the overall configuration
    all_configuration = str(args) + '\n\n' + str(hyp)
    with open(os.path.join(args.model_dirs, 'config.txt'), 'w') as config_file:
        config_file.write(all_configuration)

    args.features_dirs = '{}/features'.format(args.exp_dir)
    args.test_results_dir = '{}/test_results'.format(args.exp_dir)
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir, args.exp_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')
    clevr_dataset_train, clevr_dataset_test = initialize_dataset(
        args.clevr_dir, args.exp_dir, dictionaries, hyp['state_description'])
    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args, hyp)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method
    if args.cuda:
        model.cuda()
    # Decide wrapping from the model itself, not from the GPU count: with zero
    # GPUs or --no_cuda the model is not wrapped regardless of device_count().
    is_parallel = isinstance(model, torch.nn.DataParallel)
    core_model = model.module if is_parallel else model

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            epoch_match = re.match(r'.*epoch_(\d+)\.pth', filename)
            if epoch_match is None:
                raise ValueError(
                    'Cannot infer the epoch from resume file {}; expected a name '
                    'like RN_epoch_XX.pth'.format(filename))
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
            model.load_state_dict(_match_module_prefix(checkpoint, is_parallel))
            print('==> loaded checkpoint {}'.format(filename))
            start_epoch = int(epoch_match.group(1)) + 1
        else:
            print('Cannot load file {}'.format(filename))

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            print('==> loading conv layer from {}'.format(args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained
            # conv layer; drop any DataParallel prefix so conv keys start with 'conv.'
            pretrained_dict = _match_module_prefix(
                torch.load(args.conv_transfer_learn, map_location=lambda storage, loc: storage),
                False)
            conv_dict = core_model.conv.state_dict()
            # keep only the conv-layer entries, renamed relative to the conv module,
            # and overwrite them in the existing state dict
            conv_dict.update({
                k[len('conv.'):]: v
                for k, v in pretrained_dict.items() if k.startswith('conv.')
            })
            core_model.conv.load_state_dict(conv_dict)
            # The transferred weights stay trainable; set requires_grad = False on
            # core_model.conv.parameters() to freeze them instead.
            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    progress_bar = trange(start_epoch, args.epochs + 1)
    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        _, clevr_test_loader = reload_loaders(clevr_dataset_train, clevr_dataset_test,
                                              args.batch_size, args.test_batch_size,
                                              hyp['state_description'])
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
        return

    # perform a full training
    # TODO: find a better solution for general lr scheduling policies
    # The parentheses matter: the rate has been multiplied by lr_gamma once per
    # *completed* lr_step epochs, so a fresh run (start_epoch == 1) uses args.lr.
    lr = args.lr * args.lr_gamma ** ((start_epoch - 1) // args.lr_step)
    if args.lr_max > 0:
        lr = min(lr, args.lr_max)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=lr, weight_decay=1e-4)
    scheduler = lr_scheduler.StepLR(optimizer, args.lr_step, gamma=args.lr_gamma)
    scheduler.last_epoch = start_epoch
    print('Training ({} epochs) is starting...'.format(args.epochs))

    bs = args.batch_size
    for epoch in progress_bar:
        bs_can_grow = (args.bs_max > 0 and bs < args.bs_max) or args.bs_max < 0
        # Loaders are always built on the first epoch (otherwise they would be
        # undefined when the initial batch size already reaches bs_max), then
        # rebuilt every bs_step epochs while the batch size may still grow.
        if epoch == start_epoch or (bs_can_grow and epoch % args.bs_step == 0):
            bs = math.floor(args.batch_size * (args.bs_gamma ** (epoch // args.bs_step)))
            if args.bs_max > 0 and bs > args.bs_max:
                bs = args.bs_max
            clevr_train_loader, clevr_test_loader = reload_loaders(
                clevr_dataset_train, clevr_dataset_test, bs, args.test_batch_size,
                hyp['state_description'])
            print('Dataset reinitialized with batch size {}'.format(bs))

        # Advance the schedule only while the *current* rate is below lr_max
        # (always when lr_max < 0), then clamp so a step never overshoots it.
        current_lr = optimizer.param_groups[0]['lr']
        if (args.lr_max > 0 and current_lr < args.lr_max) or args.lr_max < 0:
            scheduler.step()
            if args.lr_max > 0:
                for group in optimizer.param_groups:
                    group['lr'] = min(group['lr'], args.lr_max)
        print('Current learning rate: {}'.format(optimizer.param_groups[0]['lr']))

        # TRAIN
        progress_bar.set_description('TRAIN')
        train(clevr_train_loader, model, optimizer, epoch, args)

        # TEST
        progress_bar.set_description('TEST')
        test(clevr_test_loader, model, epoch, dictionaries, args)

        # SAVE MODEL
        filename = 'RN_epoch_{:02d}.pth'.format(epoch)
        torch.save(model.state_dict(), os.path.join(args.model_dirs, filename))