import logging
import os

import numpy as np
import torch
import tensorboard_logger
import tensorboard_logger as tb_log  # the same package is referenced under both names below
from torch.cuda import device_count
from torch.utils.data import dataloader
from torch.utils.data.sampler import WeightedRandomSampler

# Project-local modules; the import paths are assumed from the identifiers used below.
import wvc_data
import wvc_model
import wvc_utils

_logger = logging.getLogger(__name__)


def tb_logger(self):
    # Helper bound to a workspace object: builds a TensorBoard logger
    # rooted at the workspace's tensorboard directory.
    return tensorboard_logger.Logger(self.workspace.tensorboard)
def main(model_name, output_dir, batch_size=320, num_epochs=15, valid_int=1,
         checkpoint=None, init_weights=None, num_workers=5, pre_train=False,
         train_subset_path=None, kwargs_str=None):
    # Data loading
    _logger.info("Reading WebVision Dataset")
    subset = None if train_subset_path is None else np.load(train_subset_path)
    train_db, val_db, _ = wvc_data.get_datasets(pre_train, is_lmdb=False, subset=subset)
    # Draw training examples with replacement according to per-example weights,
    # so the classes are balanced within each epoch.
    balanced_sampler = WeightedRandomSampler(train_db.sample_weight,
                                             train_db.sample_weight.size,
                                             replacement=True)
    train_data_loader = dataloader.DataLoader(train_db, batch_size=batch_size,
                                              sampler=balanced_sampler,
                                              num_workers=num_workers, pin_memory=True)
    val_data_loader = dataloader.DataLoader(val_db, batch_size=batch_size,
                                            num_workers=num_workers, pin_memory=True)

    # Model building
    _logger.info("Building Model: {}".format(model_name))
    kwargs_dic = wvc_utils.get_kwargs_dic(kwargs_str)
    _logger.info("Arguments: {}".format(kwargs_dic))
    model = wvc_model.model_factory(model_name, kwargs_dic)
    if pre_train:
        model = wvc_model.PermLearning(model)
    _logger.info("Running model with {} GPUs and {} data workers".format(
        device_count(), num_workers))
    model = torch.nn.DataParallel(model).cuda()

    # Optionally load initial weights
    if init_weights is not None:
        _logger.info("Loading weights from {}".format(init_weights))
        init_weights = torch.load(init_weights)
        model.load_state_dict(init_weights['state_dict'])

    # Objective and optimizer
    _logger.info("Setting up loss function and optimizer")
    criterion = (torch.nn.CrossEntropyLoss() if not pre_train
                 else wvc_model.WeightedMultiLabelBinaryCrossEntropy(False))
    criterion = criterion.cuda()
    init_lr = float(kwargs_dic.get("lr", 1e-1))
    # Separate parameter groups for the feature extractor and the classifier,
    # so their learning rates can be adjusted independently if needed.
    optimizer = torch.optim.SGD(
        [{'params': model.module.features.parameters(), 'lr': init_lr},
         {'params': model.module.classifier.parameters(), 'lr': init_lr}],
        lr=init_lr,
        momentum=float(kwargs_dic.get("momentum", 0.9)),
        weight_decay=float(kwargs_dic.get("weight_decay", 1e-4)))
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=int(kwargs_dic.get('lr_step', 4)),
        gamma=float(kwargs_dic.get('lr_decay', 0.1)))

    # Optionally resume from a checkpoint
    if checkpoint is not None:
        _logger.info("Resume training from {}".format(checkpoint))
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        best_acc5 = checkpoint['best_acc5']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.step(start_epoch - 1)
    else:
        start_epoch, best_acc5 = 0, 0.0

    # Training and validation loop
    _logger.info("Training...")
    tb_logger = tb_log.Logger(output_dir)
    metric_func = wvc_model.multilabel_metrics if pre_train else wvc_model.top_k_acc
    best_metric_name = 'ml_acc' if pre_train else 'acc5'
    for epoch in range(start_epoch, num_epochs):
        # Update the learning rate for this epoch
        scheduler.step()

        # Train for one epoch
        tr_loss, metrics = wvc_model.train(train_data_loader, model, criterion,
                                           optimizer, metric_func, epoch)
        _logger.info("Epoch Train {}/{}: tr_loss={:.3f}, {}".format(
            epoch, num_epochs, tr_loss,
            ", ".join(["tr_{}={:.3f}".format(k, v) for k, v in metrics.items()])))
        tb_logger.log_value('tr_loss', tr_loss, epoch)
        for k, v in metrics.items():
            tb_logger.log_value('tr_{}'.format(k), v, epoch)

        # Validation
        if (epoch + 1) % valid_int == 0:
            _logger.info("Validating...")
            val_loss, metrics = wvc_model.validate(val_data_loader, model,
                                                   criterion, metric_func, epoch)
            _logger.info("Epoch Validation {}/{}: val_loss={:.3f}, {}".format(
                epoch, num_epochs, val_loss,
                ", ".join(["val_{}={:.3f}".format(k, v)
                           for k, v in metrics.items()])))
            tb_logger.log_value('val_loss', val_loss, epoch)
            for k, v in metrics.items():
                tb_logger.log_value('val_{}'.format(k), v, epoch)
            curr_metric_val = metrics[best_metric_name]

            # Save checkpoint; keep a copy of the best model so far
            model_ckpt_name = 'checkpoint.pth.tar'
            _logger.info("Save model checkpoint to {}".format(
                os.path.join(output_dir, model_ckpt_name)))
            is_best = curr_metric_val >= best_acc5
            best_acc5 = max(curr_metric_val, best_acc5)
            wvc_model.save_checkpoint(
                {'epoch': epoch,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'best_acc5': best_acc5},
                is_best, output_dir, model_ckpt_name)
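
# wvc_model.save_checkpoint is project code not shown in this excerpt. A minimal
# sketch of the pattern it presumably follows (torch.save, plus copying the best
# checkpoint aside), kept as a comment for reference only:
#
#     import shutil
#
#     def save_checkpoint(state, is_best, output_dir, filename):
#         path = os.path.join(output_dir, filename)
#         torch.save(state, path)
#         if is_best:
#             shutil.copyfile(path, os.path.join(output_dir, 'model_best.pth.tar'))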
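
# The script's actual command-line entry point is not shown in this excerpt.
# Below is a minimal sketch, assuming an argparse-based CLI that forwards its
# arguments to main(); the flag names mirror main()'s parameters and are
# illustrative only.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Train a WebVision classifier.")
    parser.add_argument('model_name', help="Architecture name passed to wvc_model.model_factory")
    parser.add_argument('output_dir', help="Directory for TensorBoard logs and checkpoints")
    parser.add_argument('--batch_size', type=int, default=320)
    parser.add_argument('--num_epochs', type=int, default=15)
    parser.add_argument('--valid_int', type=int, default=1, help="Validate every N epochs")
    parser.add_argument('--checkpoint', default=None, help="Checkpoint to resume from")
    parser.add_argument('--init_weights', default=None, help="Weights file to initialise from")
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--pre_train', action='store_true')
    parser.add_argument('--train_subset_path', default=None)
    parser.add_argument('--kwargs_str', default=None,
                        help="Extra options, e.g. 'lr=0.01,lr_step=4'")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    main(**vars(args))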