def run_single_split(args, split, output_dir="."): train_metrics_per_epoch = [] test_metrics_per_epoch = [] best_test_metrics = None test_loader = split['test'] train_loader = split['splits'][0]['train'] assert(len(split['splits']) == 1) # Just one split. model = classifiers.versions[args.classifier_type](args, loadable_state_dict=None) if args.cuda: model = model.cuda() if args.dataparallel: print("Using dataparallel") model = nn.DataParallel(model) # New Optimizer params = list(model.parameters()) optimizer = optim.Adam( params, lr=args.lr, betas=(0.5, 0.9), weight_decay=args.weight_decay ) old_loss = None min_loss = None for epoch in range(args.max_epochs): epoch_tic = time.time() train_metrics = train_one_epoch(args, model, optimizer, train_loader) # , tobreak=(epoch==(0))) train_metrics_per_epoch.append(train_metrics) test_metrics = evaluate(args, model, test_loader) test_metrics_per_epoch.append(test_metrics) is_best = (best_test_metrics is None) or test_metrics['accuracy'] >= best_test_metrics['accuracy'] if is_best: chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "best.checkpoint")) best_test_metrics = test_metrics chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "last.checkpoint")) print( "[Epoch {}/{}] train-loss={:.4f} train-acc={:.4f} test-acc={:.4f} time={:.2f}".format( epoch, args.max_epochs, train_metrics["loss"], train_metrics["accuracy"], test_metrics["accuracy"], time.time() - epoch_tic, ) ) torch.save(train_metrics_per_epoch, os.path.join(output_dir, "train_metrics_per_epoch.checkpoint")) torch.save(test_metrics_per_epoch, os.path.join(output_dir, "test_metrics_per_epoch.checkpoint")) torch.save(best_test_metrics, os.path.join(output_dir, "best_test_metrics.checkpoint")) return train_metrics_per_epoch, test_metrics_per_epoch, best_test_metrics
def run_single_split(args, split, output_dir="."): train_metrics_per_epoch = [] test_metrics_per_epoch = [] best_test_metrics = None test_dset = split['test'] train_dset = split['splits'][0]['train'] assert(len(split['splits']) == 1) # Just one split. if args.debug: test_dset.n = DEBUG_N train_dset.n = DEBUG_N print("Starting preloading") tic = time.time() if args.not_lazy: train_dset.preload(count=args.num_workers) test_dset.preload(count=args.num_workers) print("Preloading took {}s".format(time.time() - tic)) test_loader = torch.utils.data.DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) model = classifiers.versions[args.classifier_type](args, loadable_state_dict=None) # import pdb; pdb.set_trace() if args.cuda: model = model.cuda() if args.dataparallel: print("Using dataparallel") model = nn.DataParallel(model) # New Optimizer params = list(model.parameters()) optimizer = optim.Adam( params, lr=args.lr, betas=(0.5, 0.9), weight_decay=args.weight_decay ) # New scheduler if classifiers.scheduled[args.classifier_type]: scheduler = lr_scheduler.MultiStepLR( optimizer, milestones=[5, 25], # Start at 0.001 -> 0.0001, -> 0.00001 gamma=0.1, ) else: scheduler = None old_loss = None min_loss = None for epoch in range(args.max_epochs): epoch_tic = time.time() if scheduler is not None: scheduler.step() train_metrics = train_one_epoch(args, model, optimizer, train_loader) # , tobreak=(epoch==(0))) train_metrics_per_epoch.append(train_metrics) test_metrics = evaluate(args, model, test_loader) test_metrics_per_epoch.append(test_metrics) is_best = (best_test_metrics is None) or test_metrics['accuracy'] >= best_test_metrics['accuracy'] if is_best: chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "best.checkpoint")) best_test_metrics = test_metrics chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "last.checkpoint")) # breaking logic: Not yet implemented # new_loss = train_metrics['loss'] # if min_loss is None: # min_loss = new_loss # else: # min_loss = min(new_loss, min_loss) # # to_break = ( # (epoch > args.min_epochs) and # (old_loss is not None) and # ((abs(new_loss - old_loss) < args.loss_thr)) and # ((abs(new_loss - min_loss) < args.loss_thr)) # ) # old_loss = new_loss # if to_break: # break print( "[Epoch {}/{}] train-loss={:.4f} train-acc={:.4f} test-acc={:.4f} time={:.2f}".format( epoch, args.max_epochs, train_metrics["loss"], train_metrics["accuracy"], test_metrics["accuracy"], time.time() - epoch_tic, ) ) torch.save(train_metrics_per_epoch, os.path.join(output_dir, "train_metrics_per_epoch.checkpoint")) torch.save(test_metrics_per_epoch, os.path.join(output_dir, "test_metrics_per_epoch.checkpoint")) torch.save(best_test_metrics, os.path.join(output_dir, "best_test_metrics.checkpoint")) print("Starting unload") tic = time.time() if args.not_lazy: train_dset.unload() test_dset.unload() print("Unloaded in {}s".format(time.time() - tic)) return train_metrics_per_epoch, test_metrics_per_epoch, best_test_metrics
def train_skeletal_model(args, dataset, train_loader, test_loader):
    output_dir = os.path.join(args.base_output)
    model = graph_rnn.GraphRNN(
        discrete_feature_dim=dataset.discrete_feature_dim,
        continuous_feature_dim=dataset.continuous_feature_dim,
        max_vertex_num=dataset.max_vertex_num,
        rnn_hidden_size=args.hidden_size,
        rnn_num_layers=args.num_layers
    )
    model = model.to(args.device)
    if args.dataparallel:
        raise NotImplementedError('Check if nn.DataParallel works with RNN')

    params = list(model.parameters())
    optimizer = optim.Adam(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    metrics = defaultdict(list)
    for epoch_idx in range(args.epochs):
        print('Starting epoch {}'.format(epoch_idx))
        epoch_metrics = defaultdict(list)
        tic = time.time()
        for bidx, (G_t, G_tp1, mask) in enumerate(train_loader):
            G_t = G_t.to(args.device)
            G_tp1 = G_tp1.to(args.device)
            discrete_hat, continuous_hat, adj_hat = G_tp1_hat = model(G_t)
            eos_loss, adj_loss, pos_loss = skeletal_losses(G_tp1_hat, G_tp1, dataset, mask)
            loss = eos_loss + adj_loss + pos_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_metrics['eos_loss'].append(eos_loss.item())
            epoch_metrics['adj_loss'].append(adj_loss.item())
            epoch_metrics['pos_loss'].append(pos_loss.item())
            epoch_metrics['loss'].append(loss.item())

        metrics['eos_loss'].append(np.mean(epoch_metrics['eos_loss']))
        metrics['adj_loss'].append(np.mean(epoch_metrics['adj_loss']))
        metrics['pos_loss'].append(np.mean(epoch_metrics['pos_loss']))
        metrics['loss'].append(np.mean(epoch_metrics['loss']))
        print('[{:.2f}s] Epoch {}: losses={:.3f} eos, {:.3f} adj, {:.3f} pos = {:.3f} total'.format(
            time.time() - tic,
            epoch_idx,
            metrics['eos_loss'][epoch_idx],
            metrics['adj_loss'][epoch_idx],
            metrics['pos_loss'][epoch_idx],
            metrics['loss'][epoch_idx],
        ))

        # Eval and save if necessary.
        if utils.periodic_integer_delta(epoch_idx, args.eval_every):
            test_metrics = test(args, dataset, model, test_loader, prefix='Test Dataset, Epoch {}'.format(epoch_idx))
            for k, v in test_metrics.items():
                metrics['test_{}_epoch{}'.format(k, epoch_idx)] = v
        if utils.periodic_integer_delta(epoch_idx, args.save_every):
            checkpoint_path = os.path.join(output_dir, "last.checkpoint")
            print('Saving model to {}'.format(checkpoint_path))
            chk = utils.make_checkpoint(model, optimizer, epoch_idx)
            chk['args'] = vars(args)
            torch.save(chk, checkpoint_path)
    return model, metrics
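
# Hedged usage sketch for train_skeletal_model: the loop above only assumes a dataset
# exposing discrete_feature_dim, continuous_feature_dim and max_vertex_num, and loaders
# that yield (G_t, G_tp1, mask) triples. The construction below is illustrative, not the
# repo's actual entry point.
#
# train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)
# model, metrics = train_skeletal_model(args, dataset, train_loader, test_loader)
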
def single_split_run_with_patience_stopping(args, split, output_dir=".", patience=3):
    test_dset = split['test']
    if args.debug:
        test_dset.n = DEBUG_N
    test_loader = torch.utils.data.DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)

    all_train_metrics_per_cv = {}
    all_val_metrics_per_cv = {}
    all_test_metrics_per_cv = {}
    best_train_metrics_per_cv = {}
    best_val_metrics_per_cv = {}
    best_test_metrics_per_cv = {}

    for cv_idx, dsets in enumerate(split['splits']):
        split_tic = time.time()
        train_dset = dsets['train']
        val_dset = dsets['val']

        # Make dataloaders
        if args.debug:
            train_dset.n = DEBUG_N
            val_dset.n = DEBUG_N
        val_loader = torch.utils.data.DataLoader(val_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)

        # New Model
        model = classifiers.versions[args.classifier_type](args, loadable_state_dict=None)
        if args.cuda:
            model = model.cuda()
        if args.dataparallel:
            print("Using dataparallel")
            model = nn.DataParallel(model)

        # New Optimizer
        params = list(model.parameters())
        optimizer = optim.Adam(params, lr=args.lr, betas=(0.5, 0.9), weight_decay=args.weight_decay)

        # New scheduler
        if classifiers.scheduled[args.classifier_type]:
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[5, 25],  # Start at 0.01 -> 0.001 -> 0.0001
                gamma=0.1,
            )
        else:
            scheduler = None

        cv_output_dir = os.path.join(output_dir, "inner_split{}".format(cv_idx))
        os.makedirs(cv_output_dir, exist_ok=True)

        cur_patience = patience
        all_train_metrics_per_cv[cv_idx] = []
        all_val_metrics_per_cv[cv_idx] = []
        all_test_metrics_per_cv[cv_idx] = []
        best_train_metrics_per_cv[cv_idx] = None
        best_val_metrics_per_cv[cv_idx] = None
        best_test_metrics_per_cv[cv_idx] = None
        old_val_metrics = None

        for epoch in range(args.max_epochs):
            epoch_tic = time.time()
            if scheduler is not None:
                scheduler.step()
            train_metrics = train_one_epoch(args, model, optimizer, train_loader)
            val_metrics = evaluate(args, model, val_loader)
            test_metrics = evaluate(args, model, test_loader)

            improvement = (epoch < args.min_epochs) or (old_val_metrics is None) or validation_improvement(val_metrics, old_val_metrics)
            if improvement:
                # If improvement - save model, reset patience.
                # Save checkpoint.
                is_best = (best_val_metrics_per_cv[cv_idx] is None) or val_metrics['accuracy'] >= best_val_metrics_per_cv[cv_idx]['accuracy']
                if is_best:
                    chk = utils.make_checkpoint(model, optimizer, epoch)
                    torch.save(chk, os.path.join(cv_output_dir, "best.checkpoint"))
                    best_train_metrics_per_cv[cv_idx] = train_metrics
                    best_val_metrics_per_cv[cv_idx] = val_metrics
                    best_test_metrics_per_cv[cv_idx] = test_metrics
                cur_patience = patience
            else:
                # If no improvement, drop patience.
                cur_patience -= 1
                if cur_patience == 0:
                    break  # We're done.

            # Update old_metrics
            old_val_metrics = val_metrics
            all_train_metrics_per_cv[cv_idx].append(train_metrics)
            all_val_metrics_per_cv[cv_idx].append(val_metrics)
            all_test_metrics_per_cv[cv_idx].append(test_metrics)

            print(
                "[Epoch {}/{}] train-loss={:.4f} val-acc={:.4f} test-acc={:.4f} time={:.2f}".format(
                    epoch,
                    args.max_epochs,
                    train_metrics["loss"],
                    val_metrics["accuracy"],
                    test_metrics["accuracy"],
                    time.time() - epoch_tic,
                ))

        # Save checkpoint for this inner split with the final model state.
        chk = utils.make_checkpoint(model, optimizer, epoch)
        torch.save(chk, os.path.join(cv_output_dir, "last.checkpoint"))
        metrics = {
            'train': all_train_metrics_per_cv[cv_idx],  # list of dicts
            'val': all_val_metrics_per_cv[cv_idx],  # list of dicts
            'test': all_test_metrics_per_cv[cv_idx],  # list of dicts.
        }
        torch.save(metrics, os.path.join(cv_output_dir, "metrics.checkpoint"))
        print("[Inner {}/{}] took {:.2f}s, acc={:.4f}".format(
            cv_idx,
            len(split['splits']),
            time.time() - split_tic,
            best_test_metrics_per_cv[cv_idx]["accuracy"]))

    average_test_metrics = {
        metric_name: sum([
            best_test_metrics_per_cv[cv_idx][metric_name]
            for cv_idx in all_test_metrics_per_cv.keys()
        ]) / len(split['splits'])
        for metric_name in best_test_metrics_per_cv[0].keys()
    }
    torch.save(average_test_metrics, os.path.join(output_dir, "average_best_metrics.checkpoint"))
    return all_train_metrics_per_cv, all_val_metrics_per_cv, all_test_metrics_per_cv, average_test_metrics
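
# `validation_improvement` is called above but not defined in this section. A minimal
# sketch, assuming "improvement" simply means the validation accuracy went up relative
# to the previous epoch (the repo's real criterion may differ):
def _validation_improvement_sketch(new_val_metrics, old_val_metrics):
    return new_val_metrics['accuracy'] > old_val_metrics['accuracy']
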
def run_single_split(args, split, output_dir="."): train_metrics_per_epoch = [] test_metrics_per_epoch = [] best_test_metrics = None test_dset = split['test'] train_dset = split['splits'][0]['train'] assert(len(split['splits']) == 1) # Just one split. if args.debug: test_dset.n = DEBUG_N train_dset.n = DEBUG_N test_loader = torch.utils.data.DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) model = autoencoders.AutoEncoder(args, loadable_state_dict=None) if args.cuda: model = model.cuda() if args.dataparallel: print("Using dataparallel") model = nn.DataParallel(model) # New Optimizer params = list(model.parameters()) optimizer = optim.Adam( params, lr=args.lr, betas=(0.5, 0.9), weight_decay=args.weight_decay ) # New scheduler scheduler = lr_scheduler.MultiStepLR( optimizer, milestones=[5, 25], # Start at 0.001 -> 0.0001, -> 0.00001 gamma=0.1, ) old_loss = None min_loss = None for epoch in range(args.max_epochs): epoch_tic = time.time() scheduler.step() train_metrics = train_one_epoch(args, model, optimizer, train_loader) # , tobreak=(epoch==(0))) train_metrics_per_epoch.append(train_metrics) test_metrics = evaluate(args, model, test_loader) test_metrics_per_epoch.append(test_metrics) is_best = (best_test_metrics is None) or test_metrics['mse'] <= best_test_metrics['mse'] if is_best: chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "best.checkpoint")) best_test_metrics = test_metrics chk = utils.make_checkpoint(model, optimizer, epoch) torch.save(chk, os.path.join(output_dir, "last.checkpoint")) print( "[Epoch {}/{}] train-mse={:.4f} train-l1={:.4f} test-mse={:.4f} test-l1={:.4f} time={:.2f}".format( epoch, args.max_epochs, train_metrics["mse"], train_metrics["l1"], test_metrics["mse"], test_metrics["l1"], time.time() - epoch_tic, ) ) torch.save(train_metrics_per_epoch, os.path.join(output_dir, "train_metrics_per_epoch.checkpoint")) torch.save(test_metrics_per_epoch, os.path.join(output_dir, "test_metrics_per_epoch.checkpoint")) torch.save(best_test_metrics, os.path.join(output_dir, "best_test_metrics.checkpoint")) return train_metrics_per_epoch, test_metrics_per_epoch, best_test_metrics
def run_single_split(args, split_by_study, epoch_size=10, output_dir="."):
    train_metrics_per_epoch = []
    test_metrics_per_epoch = []
    test_dset_by_study = {study: split['test'] for study, split in split_by_study.items()}
    train_dset_by_study = {study: split['splits'][0]['train'] for study, split in split_by_study.items()}
    if args.debug:
        for study in test_dset_by_study.keys():
            test_dset_by_study[study].n = DEBUG_N
            train_dset_by_study[study].n = DEBUG_N

    if args.not_lazy:
        for study in train_dset_by_study.keys():
            if study in ["archi", "la5c"]:
                print("Starting {} preloading".format(study))
                tic = time.time()
                train_dset_by_study[study].preload(count=args.num_workers)
                test_dset_by_study[study].preload(count=args.num_workers)
                print("Preloading {} took {}s".format(study, time.time() - tic))

    test_loaders_by_study = {
        study: torch.utils.data.DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
        for study, test_dset in test_dset_by_study.items()
    }
    train_loaders_by_study = {
        study: torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
        for study, train_dset in train_dset_by_study.items()
    }

    model = classifiers.versions[args.classifier_type](args, loadable_state_dict=None)
    if args.cuda:
        model = model.cuda()
    if args.dataparallel:
        print("Using dataparallel")
        model = nn.DataParallel(model)

    # New Optimizer
    params = list(model.parameters())
    optimizer = optim.Adam(
        params,
        lr=args.lr,
        betas=(0.5, 0.9),
        weight_decay=args.weight_decay
    )

    # New scheduler
    if classifiers.scheduled[args.classifier_type]:
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[5, 25],  # Start at 0.001 -> 0.0001 -> 0.00001
            gamma=0.1,
        )
    else:
        scheduler = None

    for batch_idx, batch in enumerate(utils.multi_key_infinite_iter(train_loaders_by_study)):
        if batch_idx % args.epoch_size == 0:
            if batch_idx > 0:
                if scheduler is not None:
                    scheduler.step()
                # Dump metrics, print time etc.
                for study in train_metrics.keys():
                    train_metrics[study]["loss"] /= args.epoch_size  # approximation.
                    train_metrics[study]["accuracy"] = sk_metrics.accuracy_score(ctrues[study], cpreds[study])
                    train_metrics[study]['precision'], train_metrics[study]['recall'], train_metrics[study]['f1'], train_metrics[study]['support'] = \
                        sk_metrics.precision_recall_fscore_support(ctrues[study], cpreds[study], labels=list(range(nclasses[study])))
                train_metrics_per_epoch.append(train_metrics)
                test_metrics = evaluate(args, model, test_loaders_by_study)
                test_metrics_per_epoch.append(test_metrics)
                chk = utils.make_checkpoint(model, optimizer, batch_idx)
                torch.save(chk, os.path.join(output_dir, "last.checkpoint"))
                print(" ".join(
                    ["[{}/{}] t={:.2f}".format(batch_idx, args.max_batches, time.time() - tic)]
                    + [
                        "[{} train-loss={:.4f} train-acc={:.4f} test-acc={:.4f}]".format(
                            study,
                            train_metrics_per_epoch[-1][study]["loss"],
                            train_metrics_per_epoch[-1][study]["accuracy"],
                            test_metrics[study]["accuracy"],
                        )
                        for study in args.studies
                    ]
                ))
            # Create new metrics
            train_metrics = {
                study: {
                    'loss': 0.0,
                }
                for study in args.studies
            }
            nclasses = {}
            cpreds = {study: [] for study in args.studies}
            ctrues = {study: [] for study in args.studies}
            tic = time.time()

        if batch_idx == args.max_batches:
            break

        # Do the training using this batch.
        for study, (x, _, _, cvec) in batch.items():
            nclasses[study] = len(args.meta['si2ci'][args.meta['s2i'][study]])
            N = x.shape[0]
            offset = min(args.meta['si2ci'][args.meta['s2i'][study]])
            cvec -= offset
            if args.cuda:
                x = x.cuda()
                cvec = cvec.cuda()
            study_vec = torch.tensor([args.meta['s2i'][study] for _ in range(N)], device=x.device).int()
            cpred = model(study_vec, x)
            loss = F.cross_entropy(cpred, cvec)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_metrics[study]['loss'] += N * loss.item()
            ctrues[study].extend(cvec.cpu().tolist())
            cpreds[study].extend(torch.argmax(cpred.detach(), dim=1).cpu().tolist())

    torch.save(train_metrics_per_epoch, os.path.join(output_dir, "train_metrics_per_epoch.checkpoint"))
    torch.save(test_metrics_per_epoch, os.path.join(output_dir, "test_metrics_per_epoch.checkpoint"))

    if args.not_lazy:
        for study in train_dset_by_study.keys():
            if study in ["archi", "la5c"]:
                print("Starting unload")
                tic = time.time()
                train_dset_by_study[study].unload()
                test_dset_by_study[study].unload()
                print("Unloaded in {}s".format(time.time() - tic))
    return train_metrics_per_epoch, test_metrics_per_epoch
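
# `utils.multi_key_infinite_iter` is not shown in this section; the training loop above
# only needs it to yield, indefinitely, a dict mapping each study to that study's next
# batch, restarting a DataLoader whenever it runs out. A minimal sketch under that
# assumption (the repo's actual helper may differ):
def _multi_key_infinite_iter_sketch(loaders_by_key):
    iters = {k: iter(loader) for k, loader in loaders_by_key.items()}
    while True:
        batch = {}
        for k, loader in loaders_by_key.items():
            try:
                batch[k] = next(iters[k])
            except StopIteration:
                iters[k] = iter(loader)  # this loader is exhausted; restart it
                batch[k] = next(iters[k])
        yield batch
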