def train():
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test']
    train_loader = MetaDatasetBatchReader('train', trainsets, valsets, testsets,
                                          batch_size=args['train.batch_size'])
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer
    num_train_classes = sum(train_loader.dataset_to_n_cats.values())
    model = get_model(num_train_classes, args)
    optimizer = get_optimizer(model, args, params=model.get_parameters())

    # restore the last checkpoint, if requested
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_iter, best_val_loss, best_val_acc = \
            checkpointer.restore_model(ckpt='last', strict=False)
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999  # sentinel: any real loss beats this
        best_val_acc = start_iter = 0

    # define the learning rate policy
    if args['train.lr_policy'] == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
    elif "exp_decay" in args['train.lr_policy']:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
    elif "cosine" in args['train.lr_policy']:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)

    # define the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    # training loop
    max_iter = args['train.max_iter']
    epoch_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            if i < start_iter:
                continue

            optimizer.zero_grad()

            sample = train_loader.get_train_batch(session)
            batch_dataset = sample['dataset_name']
            dataset_id = sample['dataset_ids'][0].detach().cpu().item()
            logits = model.forward(sample['images'])
            labels = sample['labels']
            batch_loss, stats_dict, _ = cross_entropy_loss(logits, labels)
            epoch_loss[batch_dataset].append(stats_dict['loss'])
            epoch_acc[batch_dataset].append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()
            lr_manager.step(i)

            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    epoch_loss[dataset_name], epoch_acc[dataset_name] = [], []
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    dataset_id = train_loader.dataset_name_to_dataset_id[valset]
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(session, valset)
                            context_features = model.embed(sample['context_images'])
                            target_features = model.embed(sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss, i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc", dataset_acc, i)
                    print(f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}")

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(dataset_accs)
                writer.add_scalar("loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar("accuracy/avg_val_acc", avg_val_acc, i)

                # save checkpoints
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                checkpointer.save_checkpoint(i, best_val_acc, best_val_loss,
                                             is_best, optimizer=optimizer,
                                             state_dict=model.get_state_dict())
                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(f"Done training with best_mean_val_loss: {best_val_loss:.3f}, "
              f"best_avg_val_acc: {best_val_acc:.2f}%")
    else:
        print(f"No training happened. Loaded checkpoint at {start_iter}, "
              f"while max_iter was {max_iter}")
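# --- Illustrative sketch (not part of the training script) -------------------
# The validation loop above calls prototype_loss(context_features,
# context_labels, target_features, target_labels) and unpacks
# (loss, stats_dict, pred) with 'loss' and 'acc' keys. A minimal
# prototypical-network version consistent with that usage is sketched below;
# the repo's actual implementation may differ (distance metric, scaling,
# extra outputs), so treat this as an assumption-labeled reference only.
import torch
import torch.nn.functional as F

def prototype_loss_sketch(context_features, context_labels,
                          target_features, target_labels):
    classes = torch.unique(context_labels)
    # class prototypes: mean embedding of the context (support) examples
    prototypes = torch.stack([
        context_features[context_labels == c].mean(dim=0) for c in classes])
    # negative squared Euclidean distance to each prototype as logits
    logits = -torch.cdist(target_features, prototypes) ** 2
    # remap the original labels to 0..n_way-1 prototype indices
    relabeled = torch.stack([(classes == y).nonzero(as_tuple=True)[0][0]
                             for y in target_labels])
    loss = F.cross_entropy(logits, relabeled)
    preds = logits.argmax(dim=1)
    acc = (preds == relabeled).float().mean().item()
    return loss, {'loss': loss.item(), 'acc': acc}, preds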
def main(args):
    args.sample_size = 28
    args.validation_split = 0.0  # no validation split: all data is for testing

    if args.dataset == 'MNIST':
        args.sample_size = 28
    elif args.dataset in ('SOP', 'Shopee'):
        if args.model == 'resnet':
            args.sample_size = 224
        elif args.model in ('vgg', 'vgg_attn'):
            args.sample_size = 224
        elif args.model == 'inception':
            args.sample_size = 299
        else:
            args.sample_size = 224

    spatial_transform_test = get_test_transform(args)
    crop_transform = get_crop_transform(args)

    if args.dataset == 'MNIST':
        test_data_loader = data_loading.CroppedMNISTLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_test, training=False)
    elif args.dataset == 'SOP':
        test_data_loader = data_loading.CroppedSOPLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_test, training=False)
    elif args.dataset == 'Shopee':
        test_data_loader = data_loading.ShopeeDataLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_test, training=False)

    args.n_classes = test_data_loader.n_classes
    data_ratio = sum(1 for sample in test_data_loader.dataset.samples if sample[1] == 0) / \
                 sum(1 for sample in test_data_loader.dataset.samples if sample[1] == 1)
    print("normal data/cropped data ratio: {}".format(data_ratio))

    # prepare the model for testing
    model, parameters = get_model(args)
    model = model.to(device)
    model.eval()

    test_logger = Logger(
        os.path.join(args.log_path, 'test_{}.log'.format(args.dataset)),
        ['batch', 'loss', 'acc'])
    revision_logger = Logger(
        os.path.join(args.log_path, 'test_config_{}.log'.format(args.dataset)),
        ['dataset', 'dataset_size', 'train_test_split', 'n_classes', 'model',
         'model_depth', 'test_batch_size', 'crop_scale', 'cropped_data_ratio',
         'shuffle'])
    revision_logger.log({
        'dataset': args.dataset,
        'dataset_size': args.dataset_size,
        'train_test_split': args.train_test_split,
        'n_classes': args.n_classes,
        'model': args.model,
        'model_depth': args.model_depth,
        'test_batch_size': args.batch_size,
        'crop_scale': args.crop_scale,
        'cropped_data_ratio': args.cropped_data_ratio,
        'shuffle': args.shuffle
    })

    # load the trained model weights
    print('loading checkpoint {}'.format(args.model_path))
    checkpoint = torch.load(args.model_path, map_location=device)
    assert args.arch == checkpoint['arch']
    model.load_state_dict(checkpoint['state_dict'])

    criterion = cross_entropy_loss()
    accuracies = AverageMeter()
    losses = AverageMeter()
    visualizer = TestVisualizer(test_data_loader.classes,
                                test_data_loader.rgb_mean,
                                test_data_loader.rgb_std, 5, args)
    show_misclassified = False
    block = False

    with torch.no_grad():
        misclassified_images_full = []
        misclassified_true_labels_full = []
        misclassified_predictions_full = []
        for i, (inputs, targets) in enumerate(test_data_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            acc = calculate_accuracy(outputs, targets)
            accuracies.update(acc, inputs.size(0))

            _, predictions = outputs.topk(1, 1, True)
            fails = predictions.squeeze() != targets
            predictions = predictions.squeeze().cpu().numpy()
            print("Number of misclassifications: {}/{}".format(
                fails.sum().item(), len(inputs)))

            if show_misclassified:
                misclassified_idxs = [idx for idx, failed in enumerate(fails) if failed]
                misclassified_images = [inputs[idx] for idx in misclassified_idxs]
                misclassified_true_labels = [targets[idx] for idx in misclassified_idxs]
                misclassified_predictions = [predictions[idx] for idx in misclassified_idxs]
                misclassified_images_full.extend(misclassified_images)
                misclassified_true_labels_full.extend(misclassified_true_labels)
                misclassified_predictions_full.extend(misclassified_predictions)

            if args.show_test_images and i % args.plot_interval == 0:
                visualizer.make_grid(inputs, targets, predictions)
                visualizer.show(block)

            test_logger.log({
                'batch': i + 1,
                'loss': losses.avg,
                'acc': accuracies.avg
            })
            print('Batch: [{0}/{1}]\t'
                  'Loss {loss.value:.4f} (avg {loss.avg:.4f})\t'
                  'Acc {acc.value:.3f} (avg {acc.avg:.3f})'.format(
                      i + 1, len(test_data_loader), loss=losses, acc=accuracies))

    if show_misclassified:
        # page through the misclassified samples in grids of n_images x n_images
        keep_going = True
        n_images = 5
        start = 0
        while keep_going:
            error_visualizer = TestVisualizer(test_data_loader.classes,
                                              test_data_loader.rgb_mean,
                                              test_data_loader.rgb_std,
                                              n_images, args)
            if n_images ** 2 < len(misclassified_images_full[start:]):
                images = misclassified_images_full[start:start + n_images ** 2]
                labels = misclassified_true_labels_full[start:start + n_images ** 2]
                predictions = misclassified_predictions_full[start:start + n_images ** 2]
            else:
                images = misclassified_images_full[start:]
                labels = misclassified_true_labels_full[start:]
                predictions = misclassified_predictions_full[start:]
                keep_going = False
            error_visualizer.make_grid(images, labels, predictions)
            error_visualizer.show(True)
            start += n_images ** 2

    print('Test log written to {}'.format(test_logger.log_file))
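# --- Illustrative sketch (not part of the test script) -----------------------
# The loop above relies on an AverageMeter exposing .update(value, n), .value
# and .avg, plus a calculate_accuracy(outputs, targets) helper. These are
# common utilities whose exact repo definitions are not shown here; minimal
# versions consistent with the usage above:
class AverageMeterSketch:
    """Tracks the latest value and a running, count-weighted average."""
    def __init__(self):
        self.value, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

def calculate_accuracy_sketch(outputs, targets):
    # fraction of top-1 predictions that match the targets
    predictions = outputs.argmax(dim=1)
    return (predictions == targets).float().mean().item()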
def train():
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test']
    train_loaders = []
    num_train_classes = dict()
    kd_weight_annealing = dict()
    for t_indx, trainset in enumerate(trainsets):
        train_loaders.append(
            MetaDatasetBatchReader('train', [trainset], valsets, testsets,
                                   batch_size=BATCHSIZES[trainset]))
        num_train_classes[trainset] = train_loaders[t_indx].num_classes('train')
        # set up annealing of the knowledge distillation loss weights
        kd_weight_annealing[trainset] = WeightAnnealing(
            T=int(args['train.cosine_anneal_freq'] * KDANNEALING[trainset]))
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer
    model = get_model(list(num_train_classes.values()), args)
    model_name_temp = args['model.name']
    # KL-divergence loss for distilling predictions
    criterion_div = DistillKL(T=4)
    # the MTL model is initialized from an ImageNet pretrained model;
    # deactivate the pretrained flag afterwards
    args['model.pretrained'] = False
    optimizer = get_optimizer(model, args, params=model.get_parameters())
    # adaptors for aligning features between the MDL and SDL models
    adaptors = adaptor(num_datasets=len(trainsets), dim_in=512,
                       opt=args['adaptor.opt']).to(device)
    optimizer_adaptor = torch.optim.Adam(adaptors.parameters(),
                                         lr=0.1, weight_decay=5e-4)

    # load the single-domain learning networks
    extractor_domains = trainsets
    dataset_models = DATASET_MODELS_DICT[args['model.backbone']]
    embed_many = get_domain_extractors(extractor_domains, dataset_models,
                                       args, num_train_classes)

    # restore the last checkpoint, if requested
    args['model.name'] = model_name_temp
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.out_last_ckpt) and args['train.resume']:
        start_iter, best_val_loss, best_val_acc = \
            checkpointer.restore_out_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999  # sentinel: any real loss beats this
        best_val_acc = start_iter = 0

    # define the learning rate policy
    if args['train.lr_policy'] == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
        lr_manager_ad = UniformStepLR(optimizer_adaptor, args, start_iter)
    elif "exp_decay" in args['train.lr_policy']:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
        lr_manager_ad = ExpDecayLR(optimizer_adaptor, args, start_iter)
    elif "cosine" in args['train.lr_policy']:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)
        lr_manager_ad = CosineAnnealRestartLR(optimizer_adaptor, args, start_iter)

    # define the summary writer
    writer = SummaryWriter(checkpointer.out_path)

    # training loop
    max_iter = args['train.max_iter']
    epoch_loss = {name: [] for name in trainsets}
    epoch_kd_f_loss = {name: [] for name in trainsets}
    epoch_kd_p_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    epoch_val_loss = {name: [] for name in valsets}
    epoch_val_acc = {name: [] for name in valsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = False
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            if i < start_iter:
                continue

            optimizer.zero_grad()
            optimizer_adaptor.zero_grad()

            # load images and labels, one batch per training domain
            samples = []
            images = dict()
            num_samples = []
            for t_indx, (name, train_loader) in enumerate(zip(trainsets, train_loaders)):
                sample = train_loader.get_train_batch(session)
                samples.append(sample)
                images[name] = sample['images']
                num_samples.append(sample['images'].size(0))

            logits, mtl_features = model.forward(
                torch.cat(list(images.values()), dim=0), num_samples, kd=True)
            stl_features, stl_logits = embed_many(images, return_type='list',
                                                  kd=True, logits=True)
            mtl_features = adaptors(mtl_features)

            batch_losses, stats_dicts = [], []
            kd_f_losses = 0
            kd_p_losses = 0
            for t_indx, trainset in enumerate(trainsets):
                batch_loss, stats_dict, _ = cross_entropy_loss(
                    logits[t_indx], samples[t_indx]['labels'])
                batch_losses.append(batch_loss * LOSSWEIGHTS[trainset])
                stats_dicts.append(stats_dict)
                batch_dataset = samples[t_indx]['dataset_name']
                epoch_loss[batch_dataset].append(stats_dict['loss'])
                epoch_acc[batch_dataset].append(stats_dict['acc'])

                # feature-level distillation: SDL teacher (ft) vs MDL student (fs)
                ft = torch.nn.functional.normalize(stl_features[t_indx],
                                                   p=2, dim=1, eps=1e-12)
                fs = torch.nn.functional.normalize(mtl_features[t_indx],
                                                   p=2, dim=1, eps=1e-12)
                kd_f_losses_ = distillation_loss(fs, ft.detach(), opt='kernelcka')
                # prediction-level distillation
                kd_p_losses_ = criterion_div(logits[t_indx], stl_logits[t_indx])
                kd_weight = kd_weight_annealing[trainset](t=i, opt='linear') \
                    * KDFLOSSWEIGHTS[trainset]
                bam_weight = kd_weight_annealing[trainset](t=i, opt='linear') \
                    * KDPLOSSWEIGHTS[trainset]
                if kd_weight > 0:
                    kd_f_losses = kd_f_losses + kd_f_losses_ * kd_weight
                if bam_weight > 0:
                    kd_p_losses = kd_p_losses + kd_p_losses_ * bam_weight
                epoch_kd_f_loss[batch_dataset].append(kd_f_losses_.item())
                epoch_kd_p_loss[batch_dataset].append(kd_p_losses_.item())

            batch_loss = torch.stack(batch_losses).sum()
            kd_f_loss = kd_f_losses * args['train.sigma']
            kd_p_loss = kd_p_losses * args['train.beta']
            batch_loss = batch_loss + kd_f_loss + kd_p_loss

            batch_loss.backward()
            optimizer.step()
            optimizer_adaptor.step()
            lr_manager.step(i)
            lr_manager_ad.step(i)

            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    writer.add_scalar(f"kd_f_loss/{dataset_name}-train_kd_f_loss",
                                      np.mean(epoch_kd_f_loss[dataset_name]), i)
                    writer.add_scalar(f"kd_p_loss/{dataset_name}-train_kd_p_loss",
                                      np.mean(epoch_kd_p_loss[dataset_name]), i)
                    epoch_loss[dataset_name], epoch_acc[dataset_name], \
                        epoch_kd_f_loss[dataset_name], \
                        epoch_kd_p_loss[dataset_name] = [], [], [], []
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(session, valset)
                            context_features = model.embed(sample['context_images'])
                            target_features = model.embed(sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    epoch_val_loss[valset].append(dataset_loss)
                    epoch_val_acc[valset].append(dataset_acc)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss, i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc", dataset_acc, i)
                    print(f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}")

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(dataset_accs)
                writer.add_scalar("loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar("accuracy/avg_val_acc", avg_val_acc, i)

                # save checkpoints
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                extra_dict = {
                    'epoch_loss': epoch_loss,
                    'epoch_acc': epoch_acc,
                    'epoch_val_loss': epoch_val_loss,
                    'epoch_val_acc': epoch_val_acc,
                    'adaptors': adaptors.state_dict(),
                    'optimizer_adaptor': optimizer_adaptor.state_dict()
                }
                checkpointer.save_checkpoint(i, best_val_acc, best_val_loss,
                                             is_best, optimizer=optimizer,
                                             state_dict=model.get_state_dict(),
                                             extra=extra_dict)
                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(f"Done training with best_mean_val_loss: {best_val_loss:.3f}, "
              f"best_avg_val_acc: {best_val_acc:.2f}%")
    else:
        print(f"No training happened. Loaded checkpoint at {start_iter}, "
              f"while max_iter was {max_iter}")
def main(args):
    # args.n_epochs = 10
    # args.crop_scale = 0.3
    # args.batch_size = 128
    args.normal_data_ratio = 0.9

    if args.dataset == 'MNIST':
        args.sample_size = 28
    elif args.dataset in ('SOP', 'Shopee'):
        if args.model == 'resnet':
            args.sample_size = 224
        elif args.model in ('vgg', 'vgg_attn'):
            args.sample_size = 224
        elif args.model == 'inception':
            args.sample_size = 299
        else:
            args.sample_size = 224

    spatial_transform_train = get_train_transform(args)
    crop_transform = get_crop_transform(args)

    if args.dataset == 'MNIST':
        train_data_loader = CroppedMNISTLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_train, training=True)
    elif args.dataset == 'SOP':
        train_data_loader = CroppedSOPLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_train, training=True)
    elif args.dataset == 'Shopee':
        train_data_loader = ShopeeDataLoader(
            args, crop_transform=crop_transform,
            spatial_transform=spatial_transform_train, training=True)
    valid_data_loader = train_data_loader.split_validation()

    args.n_channels = train_data_loader.n_channels
    args.n_classes = train_data_loader.n_classes

    model, parameters = get_model(args)
    model = model.to(device)
    criterion = losses.cross_entropy_loss()

    train_logger = Logger(os.path.join(args.log_path, 'train.log'),
                          ['epoch', 'loss', 'acc', 'lr'])
    train_batch_logger = Logger(os.path.join(args.log_path, 'train_batch.log'),
                                ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    valid_logger = Logger(os.path.join(args.log_path, 'val.log'),
                          ['epoch', 'loss', 'acc'])
    revision_logger = Logger(
        os.path.join(args.log_path, 'revision_info.log'),
        ['dataset', 'dataset_size', 'train_test_split', 'model', 'model_depth',
         'resume', 'resume_path', 'batch_size', 'n_epochs', 'sample_size',
         'crop_scale', 'crop_transform', 'cropped_data_ratio'])
    revision_logger.log({
        'dataset': args.dataset,
        'dataset_size': args.dataset_size,
        'train_test_split': args.train_test_split,
        'model': args.model,
        'model_depth': args.model_depth,
        'resume': args.resume,
        'resume_path': args.resume_path,
        'batch_size': args.batch_size,
        'n_epochs': args.n_epochs,
        'sample_size': args.sample_size,
        'crop_scale': args.crop_scale,
        'crop_transform': crop_transform.__class__.__name__,
        'cropped_data_ratio': args.cropped_data_ratio
    })

    # Nesterov momentum requires zero dampening
    dampening = 0 if args.nesterov else args.dampening
    optimizer = optim.SGD(parameters,
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          dampening=dampening,
                          weight_decay=args.weight_decay,
                          nesterov=args.nesterov)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                               patience=args.lr_patience)

    trainer = Trainer(model, criterion, optimizer, args, device,
                      train_data_loader,
                      lr_scheduler=scheduler,
                      valid_data_loader=valid_data_loader,
                      train_logger=train_logger,
                      batch_logger=train_batch_logger,
                      valid_logger=valid_logger)
    trainer.train()
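# --- Illustrative sketch (not part of the training script) -------------------
# The ReduceLROnPlateau scheduler configured above (mode='min') expects the
# Trainer to call scheduler.step(val_loss) once per epoch; after `patience`
# epochs without improvement, the learning rate is multiplied by the default
# factor of 0.1. The Trainer class is defined elsewhere in the repo, so this
# self-contained demo with a dummy parameter only illustrates the scheduler's
# behavior, not the repo's actual training loop:
def _plateau_demo():
    import torch
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=2)
    for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9]):
        sched.step(val_loss)  # stagnating loss eventually triggers a reduction
        print(epoch, opt.param_groups[0]['lr'])  # drops to 0.01 on the last epoch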