Example #1
 def create_dataset(self):
     # create train and val dataloader
     for phase, dataset_opt in self.opt['datasets'].items():
         if phase == 'train':
             train_set = create_dataset(dataset_opt, split=phase)
             train_size = int(
                 math.ceil(len(train_set) / dataset_opt['batch_size']))
             self.logger.info(
                 'Number of train images: {:,d}, iters per epoch: {:,d}'.
                 format(len(train_set), train_size))
             total_iters = int(self.opt['niter'])
             total_epochs = int(math.ceil(total_iters / train_size))
             self.logger.info(
                 'Total epochs needed: {:d} for iters {:,d}'.format(
                     total_epochs, total_iters))
             self.train_loader = create_dataloader(train_set, dataset_opt)
         elif phase == 'val':
             val_set = create_dataset(dataset_opt, split=phase)
             self.val_loader = create_dataloader(val_set, dataset_opt)
             self.logger.info('Number of val images in [{:s}]: {:d}'.format(
                 dataset_opt['name'], len(val_set)))
         elif phase == 'mix':
             mix_set = create_dataset(dataset_opt, split=phase)
             self.mix_loader = create_dataloader(mix_set, dataset_opt)
             self.logger.info('Number of mix images in [{:s}]: {:d}'.format(
                 dataset_opt['name'], len(mix_set)))
         else:
             raise NotImplementedError(
                 'Phase [{:s}] is not recognized.'.format(phase))
     assert self.train_loader is not None
     # assert self.val_loader is not None
     self.total_epochs = total_epochs
     self.total_iters = total_iters
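The `create_dataloader` helper called throughout these examples is never shown in this section. A minimal sketch of what it plausibly looks like for the option-dict call sites, assuming `dataset_opt` carries the `batch_size`, `use_shuffle`, and `n_workers` keys seen in Example #20's `opt` dict (key names and defaults here are assumptions, not the original implementation):

import torch.utils.data

def create_dataloader(dataset, dataset_opt):
    # Hypothetical sketch: wrap a Dataset in a DataLoader driven by the
    # option dict; key names follow the 'opt' dict shown in Example #20.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=dataset_opt.get('batch_size', 1),
        shuffle=dataset_opt.get('use_shuffle', True),
        num_workers=dataset_opt.get('n_workers', 0),
        pin_memory=True)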
Example #2
def main():

    # parse the options
    opts = parse_args()

    # create the dataloaders
    dataloader = {'train': create_dataloader('train_valid' if opts.no_validation else 'train', opts),
                  'valid': create_dataloader('valid', opts)}
    
    # create the model 
    model = Prover(opts)
    model.to(opts.device)
  
    # create the optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), lr=opts.learning_rate,
                                    momentum=opts.momentum,
                                    weight_decay=opts.l2)
    if opts.no_validation:
        scheduler = StepLR(optimizer, step_size=opts.lr_reduce_steps, gamma=0.1) 
    else:
        scheduler = ReduceLROnPlateau(optimizer, patience=opts.lr_reduce_patience, verbose=True)

    # load the checkpoint
    start_epoch = 0
    if opts.resume is not None:
        log('loading model checkpoint from %s..' % opts.resume)
        if opts.device.type == 'cpu':
            checkpoint = torch.load(opts.resume, map_location='cpu')
        else:
            checkpoint = torch.load(opts.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['n_epoch'] + 1
        model.to(opts.device)

    agent = Agent(model, optimizer, dataloader, opts)

    best_acc = -1.
    for n_epoch in range(start_epoch, start_epoch + opts.num_epochs):
        log('EPOCH #%d' % n_epoch)
   
        # training
        loss_train = agent.train(n_epoch)

        # save the model checkpoint
        if n_epoch % opts.save_model_epochs == 0:
            agent.save(n_epoch, opts.checkpoint_dir)

        # validation
        if not opts.no_validation:
            loss_valid = agent.valid(n_epoch)

        # reduce the learning rate
        if opts.no_validation:
            scheduler.step()
        else:
            scheduler.step(loss_valid)
Example #3
def entry_test(cfg, model_path):
    loader_test, _, _ = create_dataloader('test', **cfg.DATALOADER)

    model = get_model(cfg)
    model.to(DEVICE)
    load_checkpoint(model, model_path)

    print('\nTesting..')
    (acc_test, pre_test, rec_test, f1_test, acc_rel_test,
     acc_rel_avg_test) = validate(loader_test, model, DEVICE)
    print(f'Test at best valid, acc avg: {acc_rel_avg_test}, acc: {acc_test}, '
          f'pre: {pre_test}, rec: {rec_test}, f1: {f1_test}')
    print({x: round(y, 3) for x, y in acc_rel_test.items()})
Example #4
def generate_failure_log(cfg):
    loader_valid, _, _ = create_dataloader('valid', **cfg.DATALOADER)
    loader_test, _, _ = create_dataloader('test', **cfg.DATALOADER)

    model = get_model(cfg)
    model.to(DEVICE)
    load_best_checkpoint(model, cfg)
    print(model)

    incorr_samp_val, pred_val = get_incorr_samp(model, loader_valid)
    incorr_samp_test, pred_test = get_incorr_samp(model, loader_test)

    failure_log = {
        "valid": incorr_samp_val,
        "test": incorr_samp_test,
        "pred_val": pred_val,
        "pred_test": pred_test,
    }

    path = f"./runs/{cfg.EXP.EXP_ID}/failure_log.json"
    with open(path, 'w') as file:
        print(f"Saving the failure log in {path}")
        json.dump(failure_log, file)
Example #5
	def create_dataset(self):
		# create train and val dataloader
		for phase, dataset_opt in self.opt['datasets'].items():
			if phase == 'train':
				self.train_set = create_dataset(dataset_opt, split=phase)
				dataset_opt['resample'] = self.resample
				dataset_opt['batch_size'] = max(dataset_opt['batch_size']//8, 1)
				aux_set1 = create_dataset(dataset_opt, split=phase)
				aux_set2 = create_dataset(dataset_opt, split=phase)
				aux_set3 = create_dataset(dataset_opt, split=phase)
				# self.train_set = ConcatDataset(
				# 	train_set,
				# 	aux_set1,
				# 	aux_set2,
				# 	aux_set3
				# ),

				train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
				self.logger.info('rank {}, Number of train images: {:,d}, iters per epoch: {:,d}'.format(
					self.rank, len(self.train_set), train_size))
				total_iters = int(self.opt['niter'])
				total_epochs = int(math.ceil(total_iters / train_size))
				self.logger.info('rank {}, Total epochs needed: {:d} for iters {:,d}'.format(
					self.rank, total_epochs, total_iters))

				self.train_loader = create_dataloader(self.train_set, dataset_opt)
				aux_loader1 = create_dataloader(aux_set1, dataset_opt)
				aux_loader2 = create_dataloader(aux_set2, dataset_opt)
				aux_loader3 = create_dataloader(aux_set3, dataset_opt)
				self.train_iter = iter(self._cycle(self.train_loader))
				aux_iter1 = iter(self._cycle(aux_loader1))
				aux_iter2 = iter(self._cycle(aux_loader2))
				aux_iter3 = iter(self._cycle(aux_loader3))
				self.iters = [self.train_iter, aux_iter1, aux_iter2, aux_iter3]

			elif phase == 'val':
				val_set = create_dataset(dataset_opt, split=phase)
				self.val_loader = create_dataloader(val_set, dataset_opt)
				self.logger.info('rank {}, Number of val images in [{:s}]: {:d}'.format(self.rank, dataset_opt['name'],
				                                                                        len(val_set)))
			elif phase == 'test':
				test_set = create_dataset(dataset_opt, split=phase)
				self.test_loader = create_dataloader(test_set, dataset_opt)
				self.logger.info('rank {}, Number of test images in [{:s}]: {:d}'.format(self.rank, dataset_opt['name'],
				                                                                         len(test_set)))
			else:
				raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
		assert self.train_loader is not None
		# assert self.val_loader is not None
		self.total_epochs = total_epochs
		self.total_iters = total_iters
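Examples #5 and #12 wrap their loaders in `iter(self._cycle(...))`, but `_cycle` itself is not shown. A plausible sketch, assuming it is meant to yield batches indefinitely so training can run a fixed number of iterations across epoch boundaries; re-iterating the DataLoader (rather than using `itertools.cycle`) preserves per-epoch shuffling:

def _cycle(self, dataloader):
    # Hypothetical helper: restart the DataLoader whenever it is exhausted,
    # so the resulting generator never raises StopIteration.
    while True:
        for batch in dataloader:
            yield batch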
Example #6
def main():
    time1 = datetime.now().strftime('%Y%m%d%H%M%S')
    fh = open('/home/lx/DRNet/experiment/spatialrel/rel' + time1 + '.txt',
              'w',
              encoding='utf-8')
    args = parse_args()
    if args.custom_on == 'on':
        dataloader_train = create_dataloader_depth(args.train_split, True,
                                                   args)
        dataloader_valid = create_dataloader_depth('valid', True, args)
        dataloader_test = create_dataloader_depth('test', True, args)
    else:
        dataloader_train = create_dataloader(args.train_split, True, args)
        dataloader_valid = create_dataloader('valid', True, args)
        dataloader_test = create_dataloader('test', True, args)
    print('%d batches of training examples' % len(dataloader_train))
    print('%d batches of validation examples' % len(dataloader_valid))
    print('%d batches of testing examples' % len(dataloader_test))

    phrase_encoder = RecurrentPhraseEncoder(300, 300)

    if args.depth_on == 'on':
        model = DRNet_depth(phrase_encoder, args.feature_dim, args.num_layers,
                            args)
    else:
        model = DRNet(phrase_encoder, args.feature_dim, args.num_layers, args)

    model.cuda()

    # criterion = nn.BCEWithLogitsLoss()
    criterion = nn.CrossEntropyLoss()
    criterion.cuda()

    optimizer = torch.optim.RMSprop(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.l2)
    if args.train_split == 'train':
        scheduler = ReduceLROnPlateau(optimizer, patience=4, verbose=True)
    else:
        scheduler = StepLR(optimizer, step_size=args.patience, gamma=0.1)

    start_epoch = 0
    if args.resume is not None:
        print(' => loading model checkpoint from %s..' % args.resume)
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    best_acc = -1.

    for epoch in range(start_epoch, start_epoch + args.n_epochs):
        print('epoch #%d' % epoch)

        print('training..')
        loss, acc = train(model, criterion, optimizer, dataloader_train, epoch,
                          args)
        print('\n\ttraining loss = %.4f' % loss)
        print('\ttraining accuracy = %.3f' % acc)
        checkpoint_filename = os.path.join(
            args.log_dir, 'checkpoints/model_%.3f_%02d.pth' % (acc, epoch))
        if epoch % 5 == 0:
            print('validating..')
            torch.cuda.synchronize()
            start = time.time()
            loss, acc = test('valid', model, criterion, dataloader_valid,
                             epoch, args)
            torch.cuda.synchronize()
            end = time.time()
            dtime = ((end - start) / len(dataloader_valid) / args.batchsize)
            print('\n\tvalidation loss = %.4f' % loss)
            print('\tvalidation accuracy = %.3f' % acc)
            print('\tvalidation time per input = %.3f' % dtime)

        model.cpu()
        torch.save(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'accuracy': acc,
                'optimizer': optimizer.state_dict(),
            }, checkpoint_filename)
        model.cuda()

        if args.train_split != 'train_valid' and best_acc < acc:
            best_acc = acc
            shutil.copyfile(
                checkpoint_filename,
                os.path.join(args.log_dir, 'checkpoints/model_best.pth'))
            shutil.copyfile(
                os.path.join(args.log_dir,
                             'predictions/pred_%02d.pickle' % epoch),
                os.path.join(args.log_dir, 'predictions/pred_best.pickle'))

        if args.train_split == 'train':
            scheduler.step(loss)
        else:
            scheduler.step()

    print('testing..')
    loss, acc = test('test', model, criterion, dataloader_test, None, args)
    print('\n\ttesting loss = %.4f' % loss)
    print('\ttesting accuracy = %.3f' % acc)
    fh.close()
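Several of these examples (#2, #6, #8, #11, #15, #16) pick between `ReduceLROnPlateau` and `StepLR` depending on the training split, and the two schedulers are stepped differently: `ReduceLROnPlateau.step()` expects the monitored metric, while `StepLR.step()` takes no argument. A self-contained illustration of the two calling conventions:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

param = torch.nn.Parameter(torch.zeros(1))
opt_plateau = torch.optim.SGD([param], lr=0.1)
opt_step = torch.optim.SGD([param], lr=0.1)

plateau = ReduceLROnPlateau(opt_plateau, patience=4)
step_sched = StepLR(opt_step, step_size=8, gamma=0.1)

for epoch in range(10):
    val_loss = 1.0          # stand-in for this epoch's validation loss
    plateau.step(val_loss)  # needs the metric it monitors
    step_sched.step()       # purely schedule-driven, no metric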
Example #7
    def __init__(self,
                 lr: float = 0.0002,
                 batch_size: int = 1,
                 num_workers: int = 1):
        """
        Parameters
        lr: learning rate
        batch_size: batch size
        num_workers: the number of workers for train dataloader
        """

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Declare Generator and Discriminator
        self.netG_A2B = CycleGAN_Generator(input_nc=3, output_nc=3, ngf=64)
        self.netG_B2A = CycleGAN_Generator(input_nc=3, output_nc=3, ngf=64)
        self.netD_A = CycleGAN_Discriminator(input_nc=3, ndf=64)
        self.netD_B = CycleGAN_Discriminator(input_nc=3, ndf=64)

        self.weight_gan = 1.0
        self.weight_idt = 5.0
        self.weight_cycle = 10.0

        # Declare the Criterion for GAN loss
        # Doc for MSE Loss: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html
        # Doc for L1 Loss: https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html

        self.criterion_GAN: nn.Module = None
        self.criterion_cycle: nn.Module = None
        self.criterion_identity: nn.Module = None

        ### YOUR CODE HERE (~ 3 lines)
        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()
        ### END YOUR CODE

        self.optimizerG: optim.Optimizer = None
        self.optimizerD_A: optim.Optimizer = None
        self.optimizerD_B: optim.Optimizer = None

        # Declare the Optimizer for training
        # Doc for Adam optimizer: https://pytorch.org/docs/stable/optim.html#torch.optim.Adam

        ### YOUR CODE HERE (~ 3 lines) # Hint: itertools.chain() is useful
        self.optimizerG = torch.optim.Adam(itertools.chain(
            self.netG_A2B.parameters(), self.netG_B2A.parameters()),
                                           lr=lr,
                                           betas=(0.5, 0.999))
        self.optimizerD_A = torch.optim.Adam(self.netD_A.parameters(),
                                             lr=lr,
                                             betas=(0.5, 0.999))
        self.optimizerD_B = torch.optim.Adam(self.netD_B.parameters(),
                                             lr=lr,
                                             betas=(0.5, 0.999))
        ### END YOUR CODE

        # Declare the DataLoader
        # You have to implement 'Summer2WinterDataset' in 'dataloader.py'
        # Note1: Use 'create_dataloader' function implemented in 'dataloader.py'

        ### YOUR CODE HERE (~ 1 line)
        self.trainloader, self.testloader = create_dataloader(
            'summer2winter', batch_size=batch_size, num_workers=num_workers)
        ### END YOUR CODE

        # Make directory
        os.makedirs('./results/cyclegan/images/', exist_ok=True)
        os.makedirs('./results/cyclegan/checkpoints/', exist_ok=True)
Example #8
def main():
    args = parse_args()

    dataloader_train = create_dataloader(args.train_split, True, args)
    dataloader_valid = create_dataloader("valid", True, args)
    dataloader_test = create_dataloader("test", True, args)
    print("%d batches of training examples" % len(dataloader_train))
    print("%d batches of validation examples" % len(dataloader_valid))
    print("%d batches of testing examples" % len(dataloader_test))

    phrase_encoder = RecurrentPhraseEncoder(300, 300)
    if args.model == "drnet":
        model = DRNet(phrase_encoder, args.feature_dim)
    elif args.model == "vtranse":
        model = VtransE(phrase_encoder, args.visual_feature_size,
                        args.predicate_embedding_dim)
    elif args.model == "vipcnn":
        model = VipCNN(roi_size=args.roi_size, backbone=args.backbone)
    else:
        model = PPRFCN(backbone=args.backbone)
    model.cuda()
    print(model)
    criterion = nn.BCEWithLogitsLoss()
    criterion.cuda()

    optimizer = torch.optim.RMSprop(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.l2,
    )
    if args.train_split == "train":
        scheduler = ReduceLROnPlateau(optimizer, patience=4, verbose=True)
    else:
        scheduler = StepLR(optimizer, step_size=args.patience, gamma=0.1)

    start_epoch = 0
    if args.resume is not None:
        print(" => loading model checkpoint from %s.." % args.resume)
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["state_dict"])
        model.cuda()
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint["epoch"]

    best_acc = -1.0

    for epoch in range(start_epoch, start_epoch + args.n_epochs):
        print("epoch #%d" % epoch)

        print("training..")
        loss, acc = train(model, criterion, optimizer, dataloader_train, epoch,
                          args)
        print("\n\ttraining loss = %.4f" % loss)
        print("\ttraining accuracy = %.3f" % acc)

        if args.train_split != "train_valid":
            print("validating..")
            loss, accs = test("valid", model, criterion, dataloader_valid,
                              epoch, args)
            print("\n\tvalidation loss = %.4f" % loss)
            print("\tvalidation accuracy = %.3f" % accs["overall"])
            for predi in accs:
                if predi != "overall":
                    print("\t\t%s: %.3f" % (predi, accs[predi]))

        checkpoint_filename = os.path.join(
            args.log_dir, "checkpoints/model_%02d.pth" % epoch)
        model.cpu()
        torch.save(
            {
                "epoch": epoch + 1,
                "args": args,
                "state_dict": model.state_dict(),
                "accuracy": acc,
                "optimizer": optimizer.state_dict(),
            },
            checkpoint_filename,
        )
        model.cuda()

        if args.train_split != "train_valid" and best_acc < acc:
            best_acc = acc
            shutil.copyfile(
                checkpoint_filename,
                os.path.join(args.log_dir, "checkpoints/model_best.pth"),
            )
            shutil.copyfile(
                os.path.join(args.log_dir,
                             "predictions/pred_%02d.pickle" % epoch),
                os.path.join(args.log_dir, "predictions/pred_best.pickle"),
            )

        if args.train_split == "train":
            scheduler.step(loss)
        else:
            scheduler.step()

    print("testing..")
    loss, accs = test("test", model, criterion, dataloader_test, None, args)
    print("\n\ttesting loss = %.4f" % loss)
    print("\ttesting accuracy = %.3f" % accs["overall"])
    for predi in accs:
        if predi != "overall":
            print("\t\t%s: %.3f" % (predi, accs[predi]))
Example #9
def train_val(cfg: DictConfig) -> None:

    # create dataloaders for training and validation
    loader_train, vocabs = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_train),
        "train",
        cfg.encoder,
        None,
        cfg.batch_size,
        cfg.num_workers,
    )
    assert vocabs is not None
    loader_val, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )

    # create the model
    model = Parser(vocabs, cfg)
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % count_params(model))

    # create the optimizer
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    start_epoch = 0
    if cfg.resume is not None:  # resume training from a checkpoint
        checkpoint = load_model(cfg.resume)
        model.load_state_dict(checkpoint["model_state"])
        start_epoch = checkpoint["epoch"] + 1
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        del checkpoint
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=cfg.learning_rate_patience,
        cooldown=cfg.learning_rate_cooldown,
        verbose=True,
    )

    # start training and validation
    best_f1_score = -1.0
    num_iters = 0

    for epoch in range(start_epoch, cfg.num_epochs):
        log.info("Epoch #%d" % epoch)

        if not cfg.skip_training:
            log.info("Training..")
            num_iters, accuracy_train, loss_train = train(
                num_iters,
                loader_train,
                model,
                optimizer,
                vocabs["label"],
                cfg,
            )
            log.info("Action accuracy: %.03f, Loss: %.03f" %
                     (accuracy_train, loss_train))

        log.info("Validating..")
        f1_score_val = validate(loader_val, model, cfg)

        log.info(
            "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
            % (
                f1_score_val.fscore,
                f1_score_val.complete_match,
                f1_score_val.precision,
                f1_score_val.recall,
            ))

        if f1_score_val.fscore > best_f1_score:
            log.info("F1 score has improved")
            best_f1_score = f1_score_val.fscore

        scheduler.step(best_f1_score)

        save_checkpoint(
            "model_latest.pth",
            epoch,
            model,
            optimizer,
            f1_score_val.fscore,
            vocabs,
            cfg,
        )
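The `save_checkpoint` call above, together with the resume branch earlier in `train_val` and the loads in Example #13, implies a checkpoint with `epoch`, `model_state`, `optimizer_state`, `vocabs`, and `cfg` fields. A sketch consistent with those call sites; the exact field set is inferred, not confirmed:

import torch

def save_checkpoint(filename, epoch, model, optimizer, fscore, vocabs, cfg):
    # Inferred layout: bundle every field that the resume branch above and
    # the testing entry point in Example #13 read back out.
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'fscore': fscore,
        'vocabs': vocabs,
        'cfg': cfg,
    }, filename)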
Example #10
def entry_train(cfg, record_file=""):
    loader_train, _, _ = create_dataloader(split='train', **cfg.DATALOADER)
    loader_valid, _, _ = create_dataloader('valid', **cfg.DATALOADER)
    loader_test, _, _ = create_dataloader('test', **cfg.DATALOADER)

    model = get_model(cfg)
    model.to(DEVICE)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.TRAIN.learning_rate,
                                 weight_decay=cfg.TRAIN.l2)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='max',
                                  factor=0.5,
                                  patience=10,
                                  verbose=True)

    best_acc_rel_avg_valid = -1
    best_epoch_rel_avg_valid = 0
    best_acc_rel_avg_test = -1

    log_dir = f"./runs/{cfg.EXP.EXP_ID}"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tb = TensorboardManager(log_dir)
    for epoch in range(cfg.TRAIN.num_epochs):
        print('\nEpoch #%d' % epoch)

        print('Training..')
        (acc_train, pre_train, rec_train, f1_train, acc_rel_train,
         acc_rel_avg_train) = train(loader_train, model, optimizer,
                                    DEVICE, cfg.TRAIN.weighted_loss)
        print(f'Train, acc avg: {acc_rel_avg_train} acc: {acc_train},'
              f' pre: {pre_train}, rec: {rec_train}, f1: {f1_train}')
        print({x: round(y, 3) for x, y in acc_rel_train.items()})
        tb.update('train', epoch, {'acc': acc_train})

        print('\nValidating..')
        (acc_valid, pre_valid, rec_valid, f1_valid, acc_rel_valid,
         acc_rel_avg_valid) = validate(loader_valid, model, DEVICE)
        print(f'Valid, acc avg: {acc_rel_avg_valid} acc: {acc_valid},'
              f' pre: {pre_valid}, rec: {rec_valid}, f1: {f1_valid}')
        print({x: round(y, 3) for x, y in acc_rel_valid.items()})
        tb.update('val', epoch, {'acc': acc_valid})

        print('\nTesting..')
        (acc_test, pre_test, rec_test, f1_test, acc_rel_test,
         acc_rel_avg_test) = validate(loader_test, model, DEVICE)
        print(f'Test, acc avg: {acc_rel_avg_test} acc: {acc_test},'
              f' pre: {pre_test}, rec: {rec_test}, f1: {f1_test}')
        print({x: round(y, 3) for x, y in acc_rel_test.items()})

        if acc_rel_avg_valid > best_acc_rel_avg_valid:
            print('Accuracy has improved')
            best_acc_rel_avg_valid = acc_rel_avg_valid
            best_epoch_rel_avg_valid = epoch

            save_checkpoint(epoch, model, optimizer, acc_rel_avg_valid, cfg)
        if acc_rel_avg_test > best_acc_rel_avg_test:
            best_acc_rel_avg_test = acc_rel_avg_test

        if (epoch - best_epoch_rel_avg_valid) > cfg.TRAIN.early_stop:
            print(f"Early stopping at {epoch} as val acc did not improve"
                  f" for {cfg.TRAIN.early_stop} epochs.")
            break

        scheduler.step(acc_train)

    print('\nTesting..')
    load_best_checkpoint(model, cfg)
    (acc_test, pre_test, rec_test, f1_test, acc_rel_test,
     acc_rel_avg_test) = validate(loader_test, model, DEVICE)
    print(f'Best valid, acc: {best_acc_rel_avg_valid}')
    print(f'Best test, acc: {best_acc_rel_avg_test}')
    print(f'Test at best valid, acc avg: {acc_rel_avg_test}, acc: {acc_test},'
          f' pre: {pre_test}, rec: {rec_test}, f1: {f1_test}')
    print({x: round(y, 3) for x, y in acc_rel_test.items()})

    if record_file != "":
        exp = RecordExp(record_file)
        exp.record_param(flatten_dict(dict(cfg)))
        exp.record_result({
            "final_train": acc_rel_avg_train,
            "best_val": best_acc_rel_avg_valid,
            "best_test": best_acc_rel_avg_test,
            "final_test": acc_rel_avg_test
        })
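`flatten_dict` above presumably turns the nested config into flat `section.key` names before recording. A minimal sketch of that behavior (the separator and recursion strategy are guesses):

def flatten_dict(d, parent_key='', sep='.'):
    # Hypothetical helper: flatten {'TRAIN': {'l2': 0.0}} into
    # {'TRAIN.l2': 0.0} so parameters can be recorded as flat columns.
    items = {}
    for k, v in d.items():
        key = f'{parent_key}{sep}{k}' if parent_key else str(k)
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep=sep))
        else:
            items[key] = v
    return items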
Example #11
def main():
    args = parse_args()

    dataloader_train = create_dataloader(args.train_split, False, args)
    dataloader_valid = create_dataloader('valid', False, args)
    dataloader_test = create_dataloader('test', False, args)
    print('%d batches of training examples' % len(dataloader_train))
    print('%d batches of validation examples' % len(dataloader_valid))
    print('%d batches of testing examples' % len(dataloader_test))

    phrase_encoder = RecurrentPhraseEncoder(300, 300)
    model = SimpleLanguageOnlyModel(phrase_encoder, args.feature_dim, 9)
    criterion = nn.BCEWithLogitsLoss()
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    if args.train_split == 'train_valid':
        scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
    else:
        scheduler = ReduceLROnPlateau(optimizer, patience=5, verbose=True)

    best_acc = -1.

    for epoch in range(args.n_epochs):
        print('epoch #%d' % epoch)

        print('training..')
        loss, acc = train(model, criterion, optimizer, dataloader_train, epoch,
                          args)
        print('\n\ttraining loss = %.4f' % loss)
        print('\ttraining accuracy = %.3f' % acc)

        print('validating..')
        loss, accs = test('valid', model, criterion, dataloader_valid, epoch,
                          args)
        print('\n\tvalidation loss = %.4f' % loss)
        print('\tvalidation accuracy = %.3f' % accs['overall'])
        for predi in accs:
            if predi != 'overall':
                print('\t\t%s: %.3f' % (predi, accs[predi]))

        checkpoint_filename = os.path.join(
            args.log_dir, 'checkpoints/model_%02d.pth' % epoch)
        model.cpu()
        torch.save(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'accuracy': acc,
                'optimizer': optimizer.state_dict(),
            }, checkpoint_filename)
        if torch.cuda.is_available():
            model.cuda()
        if best_acc < acc:
            best_acc = acc
            shutil.copyfile(
                checkpoint_filename,
                os.path.join(args.log_dir, 'checkpoints/model_best.pth'))
            shutil.copyfile(
                os.path.join(args.log_dir,
                             'predictions/pred_%02d.pickle' % epoch),
                os.path.join(args.log_dir, 'predictions/pred_best.pickle'))

        if args.train_split == 'train_valid':
            scheduler.step()
        else:
            scheduler.step(loss)

    print('testing..')
    loss, accs = test('test', model, criterion, dataloader_test, epoch, args)
    print('\n\ttesting loss = %.4f' % loss)
    print('\ttesting accuracy = %.3f' % accs['overall'])
    for predi in accs:
        if predi != 'overall':
            print('\t\t%s: %.3f' % (predi, accs[predi]))
Example #12
	def create_dataset(self):
		# create train and val dataloader
		for phase, dataset_opt in self.opt['datasets'].items():
			if phase == 'train':
				if self.opt['varyOnCV']:
					train_set = create_dataset(dataset_opt, split=phase)
					nfold = self.num_tasks
					# nfold = 5
					folds = cv_split(train_set, nfold, self.comm)
					self.folds_loaders = [create_dataloader(f, dataset_opt) for f in folds]
					self.train_set = folds.pop(dataset_opt['fold']-1)
					self.logger.info("split into {} folds, currently in fold {}".format(nfold, dataset_opt['fold']))
					# self.val_set = val_fold
					if self.opt['varyOnSample']:
						self.train_set = ResampleDataset(self.train_set)
				else:
					self.train_set = create_dataset(dataset_opt, split=phase)
					# self.opt['varyOnSample'] = True
					if self.opt['varyOnSample']:
						self.train_set = ResampleDataset(self.train_set)

					# self.opt['create_val'] = True
					if self.opt['create_val']:
						# task0 for 0.1, else for random in [0, 0.3]
						ratio = 0.1
						# if self.task_id == 0:
						# 	ratio = 0.1
						# else:
						# 	ratio = np.random.choice([0.1,0.2,0.3])
						self.train_set, val_set = train_val_split(self.train_set, ratio, comm=None)	#self.comm
						# val_folds = self.comm.allgather(val_set)
						# self.logger.info([vf[0] for vf in val_folds])	# test if val_folds in all ranks are the same
						# self.folds_loaders = [create_dataloader(f, dataset_opt) for f in val_folds]
						self.val_loader = create_dataloader(val_set, dataset_opt)
						self.logger.info(
							'rank {}, Number of val images in [{:s}]: {:d}'.format(self.rank, dataset_opt['name'],
							                                                       len(val_set)))

				# self.opt['varyOnSample'] = True
				self.train_loader = create_dataloader(self.train_set, dataset_opt)
				self.train_iter = iter(self._cycle(self.train_loader))
				train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
				self.logger.info('rank {}, Number of train images: {:,d}, iters per epoch: {:,d}'.format(self.rank,
					len(self.train_set), train_size))
				total_iters = int(self.opt['niter'])
				total_epochs = int(math.ceil(total_iters / train_size))
				self.logger.info('rank {}, Total epochs needed: {:d} for iters {:,d}'.format(self.rank,
					total_epochs, total_iters))
				self.total_epochs = total_epochs
				self.total_iters = total_iters

			elif phase == 'val' and not self.opt['varyOnCV'] and not self.opt['create_val']:
				val_set = create_dataset(dataset_opt, split=phase)
				self.val_loader = create_dataloader(val_set, dataset_opt)
				self.logger.info('rank {}, Number of val images in [{:s}]: {:d}'.format(self.rank, dataset_opt['name'], len(val_set)))
			elif phase == 'test':
				test_set = create_dataset(dataset_opt, split=phase)
				self.test_loader = create_dataloader(test_set, dataset_opt)
				self.logger.info('rank {}, Number of test images in [{:s}]: {:d}'.format(self.rank, dataset_opt['name'], len(test_set)))
			else:
				raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
		assert self.train_loader is not None
Example #13
def main(cfg: DictConfig) -> None:
    "The entry point for testing"

    assert cfg.model_path is not None, "Need to specify model_path for testing."
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # restore the hyperparameters used for training
    model_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % model_path)
    checkpoint = load_model(model_path)
    restore_hyperparams(checkpoint["cfg"], cfg)

    # create dataloaders for validation and testing
    vocabs = checkpoint["vocabs"]
    loader_val, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )
    loader_test, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_test),
        "test",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )

    # restore the trained model checkpoint
    model = Parser(vocabs, cfg)
    model.load_state_dict(checkpoint["model_state"])
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % sum([p.numel() for p in model.parameters()]))

    # validation
    log.info("Validating..")
    f1_score = validate(loader_val, model, cfg)
    log.info(
        "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
        % (
            f1_score.fscore,
            f1_score.complete_match,
            f1_score.precision,
            f1_score.recall,
        ))

    # testing
    log.info("Testing..")
    if cfg.beam_size > 1:
        log.info("Performing beam search..")
        f1_score = beam_search(loader_test, model, cfg)
    else:
        log.info("Running without beam search..")
        f1_score = validate(loader_test, model, cfg)
    log.info(
        "Testing F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
        % (
            f1_score.fscore,
            f1_score.complete_match,
            f1_score.precision,
            f1_score.recall,
        ))
Example #14
    def __init__(self,
                 type: str = 'gan',
                 lr: float = 0.0002,
                 batch_size: int = 64,
                 num_workers: int = 1,
                 device=None):
        """
        Parameters
        type: gan loss type: 'gan' or 'lsgan' or 'wgan' or 'wgan-gp'
        lr: learning rate
        batch_size: batch size
        num_workers: the number of workers for train dataloader
        """

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Declare Generator and Discriminator
        self.type = type
        self.netG = DCGAN_Generator()
        self.netD = DCGAN_Discriminator(type=type)

        # Declare the Criterion for GAN loss
        # Doc for Binary Cross Entropy Loss: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
        # Doc for MSE Loss: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html
        # Note1: Implement 'GPLoss' function before using WGAN-GP loss.
        # Note2: It is okay not to implement the criterion for WGAN.

        self.criterion: nn.Module = None
        self.D_weight = 1
        ### YOUR CODE HERE (~ 8 lines)
        if self.type == 'gan':
            self.criterion = nn.BCELoss()
        elif self.type == 'lsgan':
            self.criterion = nn.MSELoss()
            self.D_weight = 0.5
        elif self.type == 'wgan':
            self.n_critic = 5
        elif self.type == 'wgan-gp':
            self.lambda_term = 10
            self.n_critic = 5
            self.GP_loss = GPLoss(self.device)
        else:
            raise NotImplementedError
        ### END YOUR CODE

        # Declare the Optimizer for training
        # Doc for Adam optimizer: https://pytorch.org/docs/stable/optim.html#torch.optim.Adam

        self.optimizerG: optim.Optimizer = None
        self.optimizerD: optim.Optimizer = None

        ### YOUR CODE HERE (~ 2 lines)
        self.optimizerG = torch.optim.Adam(self.netG.parameters(),
                                           lr=lr,
                                           betas=(0.5, 0.999))
        self.optimizerD = torch.optim.Adam(self.netD.parameters(),
                                           lr=lr,
                                           betas=(0.5, 0.999))
        ### END YOUR CODE

        # Declare the DataLoader
        # Note1: Use 'create_dataloader' function implemented in 'dataloader.py'

        ### YOUR CODE HERE (~ 1 line)
        self.trainloader, self.testloader = create_dataloader(
            'cifar10', batch_size=batch_size, num_workers=num_workers)
        ### END YOUR CODE

        # Make directory
        os.makedirs(os.path.join('./results/', self.type, 'images'),
                    exist_ok=True)
        os.makedirs(os.path.join('./results/', self.type, 'checkpoints'),
                    exist_ok=True)
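Examples #7 and #14 expect `create_dataloader(name, batch_size, num_workers)` to return a `(trainloader, testloader)` pair. A runnable sketch for the `'cifar10'` case of Example #14 using torchvision; the `'summer2winter'` case of Example #7 would need the custom `Summer2WinterDataset` the comments ask for, so it is left unimplemented here:

import torch.utils.data
import torchvision
import torchvision.transforms as T

def create_dataloader(name, batch_size=64, num_workers=1):
    # Hypothetical factory matching the call sites in Examples #7 and #14.
    if name == 'cifar10':
        tf = T.Compose([T.ToTensor(),
                        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_set = torchvision.datasets.CIFAR10(
            './data', train=True, download=True, transform=tf)
        test_set = torchvision.datasets.CIFAR10(
            './data', train=False, download=True, transform=tf)
    else:
        # 'summer2winter' requires the Summer2WinterDataset from dataloader.py.
        raise NotImplementedError(name)
    trainloader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader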
Example #15
def main():
    args = parse_args()

    dataloader_train = create_dataloader(args.train_split, True, args)
    dataloader_valid = create_dataloader('valid', True, args)
    dataloader_test = create_dataloader('test', True, args)
    print('%d batches of training examples' % len(dataloader_train))
    print('%d batches of validation examples' % len(dataloader_valid))
    print('%d batches of testing examples' % len(dataloader_test))

    phrase_encoder = RecurrentPhraseEncoder(300, 300)
    if args.model == 'drnet':
        model = DRNet(phrase_encoder, args.feature_dim, args.pretrained)
    elif args.model == 'vtranse':
        model = VtransE(phrase_encoder, args.visual_feature_size, args.predicate_embedding_dim)
    elif args.model == 'vipcnn':
        model = VipCNN(roi_size=args.roi_size, backbone=args.backbone)
    else:
        model = PPRFCN(backbone=args.backbone)
    model.cuda()
    print(model)
    criterion = nn.BCEWithLogitsLoss()
    criterion.cuda()

    optimizer = torch.optim.RMSprop([p for p in model.parameters() if p.requires_grad], lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    if args.train_split == 'train':
        scheduler = ReduceLROnPlateau(optimizer, patience=4, verbose=True)
    else:
        scheduler = StepLR(optimizer, step_size=args.patience, gamma=0.1) 

    start_epoch = 0
    if args.resume is not None:
        print(' => loading model checkpoint from %s..' % args.resume)
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    best_acc = -1.

    for epoch in range(start_epoch, start_epoch + args.n_epochs):
        print('epoch #%d' % epoch)
   
        print('training..')
        loss, acc = train(model, criterion, optimizer, dataloader_train, epoch, args)
        print('\n\ttraining loss = %.4f' % loss)
        print('\ttraining accuracy = %.3f' % acc)

        if args.train_split != 'train_valid':
            print('validating..')
            loss, accs = test('valid', model, criterion, dataloader_valid, epoch, args)
            print('\n\tvalidation loss = %.4f' % loss)
            print('\tvalidation accuracy = %.3f' % accs['overall'])
            for predi in accs:
                if predi != 'overall':
                    print('\t\t%s: %.3f' % (predi, accs[predi]))

        checkpoint_filename = os.path.join(args.log_dir, 'checkpoints/model_%02d.pth' % epoch)
        model.cpu()
        torch.save({'epoch': epoch + 1,
                    'args': args,
                    'state_dict': model.state_dict(),
                    'accuracy': acc,
                    'optimizer': optimizer.state_dict(),
                   }, checkpoint_filename)
        model.cuda()
     
        if args.train_split != 'train_valid' and best_acc < acc:
            best_acc = acc
            shutil.copyfile(checkpoint_filename, os.path.join(args.log_dir, 'checkpoints/model_best.pth'))
            shutil.copyfile(os.path.join(args.log_dir, 'predictions/pred_%02d.pickle' % epoch), 
                            os.path.join(args.log_dir, 'predictions/pred_best.pickle'))

        if args.train_split == 'train':
            scheduler.step(loss)
        else:
            scheduler.step()

    print('testing..')
    loss, accs = test('test', model, criterion, dataloader_test, None, args)
    print('\n\ttesting loss = %.4f' % loss)
    print('\ttesting accuracy = %.3f' % accs['overall'])
    for predi in accs:
        if predi != 'overall':
            print('\t\t%s: %.3f' % (predi, accs[predi]))
Example #16
def main():
    args = parse_args()

    dataloader_train = create_dataloader(args.train_split, False, args)
    dataloader_valid = create_dataloader("valid", False, args)
    dataloader_test = create_dataloader("test", False, args)
    print("%d batches of training examples" % len(dataloader_train))
    print("%d batches of validation examples" % len(dataloader_valid))
    print("%d batches of testing examples" % len(dataloader_test))

    model = SimpleSpatialModel(4, args.feature_dim, 9)
    print(model)
    criterion = nn.BCEWithLogitsLoss()
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.l2,
    )
    if args.train_split == "train_valid":
        scheduler = StepLR(optimizer, step_size=21, gamma=0.1)
    else:
        scheduler = ReduceLROnPlateau(optimizer, patience=5, verbose=True)
    best_acc = -1.0

    for epoch in range(args.n_epochs):
        print("epoch #%d" % epoch)

        print("training..")
        loss, acc = train(model, criterion, optimizer, dataloader_train, epoch, args)
        print("\n\ttraining loss = %.4f" % loss)
        print("\ttraining accuracy = %.3f" % acc)

        print("validating..")
        loss, accs = test("valid", model, criterion, dataloader_valid, epoch, args)
        print("\n\tvalidation loss = %.4f" % loss)
        print("\tvalidation accuracy = %.3f" % accs["overall"])
        for predi in accs:
            if predi != "overall":
                print("\t\t%s: %.3f" % (predi, accs[predi]))

        checkpoint_filename = os.path.join(
            args.log_dir, "checkpoints/model_%02d.pth" % epoch
        )
        model.cpu()
        torch.save(
            {
                "epoch": epoch + 1,
                "args": args,
                "state_dict": model.state_dict(),
                "accuracy": acc,
                "optimizer": optimizer.state_dict(),
            },
            checkpoint_filename,
        )
        if torch.cuda.is_available():
            model.cuda()

        if best_acc < acc:
            best_acc = acc
            shutil.copyfile(
                checkpoint_filename,
                os.path.join(args.log_dir, "checkpoints/model_best.pth"),
            )
            shutil.copyfile(
                os.path.join(args.log_dir, "predictions/pred_%02d.pickle" % epoch),
                os.path.join(args.log_dir, "predictions/pred_best.pickle"),
            )

        if args.train_split == "train_valid":
            scheduler.step()
        else:
            scheduler.step(loss)

    print("testing..")
    _, accs = test("test", model, criterion, dataloader_test, epoch, args)
    print("\ttesting accuracies = %.3f" % accs["overall"])
    for predi in accs:
        if predi != "overall":
            print("\t\t%s: %.3f" % (predi, accs[predi]))
Example #17
    tokenizer = AutoTokenizer.from_pretrained(
        'allenai/scibert_scivocab_uncased')

    batch_size = int(HYPERPARAMS["BATCH_SIZE"])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_sentences, train_labels = prepare_data(
        input_dir=PATHS["TRAIN_DATA_PATH"], oversample=True)

    trial_sentences, trial_labels = prepare_data(
        input_dir=PATHS["VALIDATION_DATA_PATH"], oversample=True)

    train_sentences = train_sentences + trial_sentences
    train_labels = train_labels + trial_labels

    train_dataloader = create_dataloader(train_sentences, train_labels,
                                         tokenizer)

    test_sentences, test_labels = prepare_data(
        input_dir=PATHS["TEST_DATA_PATH"], oversample=False)

    test_dataloader = create_dataloader(test_sentences, test_labels, tokenizer)

    if args.model == "baseline":
        classifier, time_taken = train_baseline(train_dataloader)
        mcc, report = predict_labels_baseline(classifier, test_dataloader)

        print(f"Time Taken is {time_taken} seconds")
        print(f"MCC Score is {mcc}")
        print(report)

    if args.model == "bert-linear":
Example #18
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.augmentation = args.data_augmentation
        self.device = torch.device('cuda' if len(args.gpu_ids) != 0 else 'cpu')
        args.device = self.device

        ## init dataloader
        if args.phase == 'train':
            self.train_dataset = TrainSet(self.args)
            if args.dist:
                dataset_ratio = 1
                train_sampler = DistIterSampler(self.train_dataset,
                                                args.world_size, args.rank,
                                                dataset_ratio)
                self.train_dataloader = create_dataloader(
                    self.train_dataset, args, train_sampler)
            else:
                self.train_dataloader = DataLoader(
                    self.train_dataset,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    shuffle=True)

        testset_ = getattr(importlib.import_module('dataloader.dataset'),
                           args.testset, None)
        self.test_dataset = testset_(self.args)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=1,
                                          num_workers=args.num_workers,
                                          shuffle=False)

        ## init network
        self.net = define_G(args)
        if args.resume:
            self.load_networks('net', self.args.resume)

        if args.rank <= 0:
            logging.info('----- generator parameters: %f -----' %
                         (sum(param.numel()
                              for param in self.net.parameters()) / (10**6)))

        ## init loss and optimizer
        if args.phase == 'train':
            if args.rank <= 0:
                logging.info('init criterion and optimizer...')
            g_params = [self.net.parameters()]

            self.criterion_mse = nn.MSELoss().to(self.device)
            if args.loss_mse:
                self.criterion_mse = nn.MSELoss().to(self.device)
                self.lambda_mse = args.lambda_mse
                if args.rank <= 0:
                    logging.info('  using mse loss...')

            if args.loss_l1:
                self.criterion_l1 = nn.L1Loss().to(self.device)
                self.lambda_l1 = args.lambda_l1
                if args.rank <= 0:
                    logging.info('  using l1 loss...')

            if args.loss_adv:
                self.criterion_adv = AdversarialLoss(gpu_ids=args.gpu_ids,
                                                     dist=args.dist,
                                                     gan_type=args.gan_type,
                                                     gan_k=1,
                                                     lr_dis=args.lr_D,
                                                     train_crop_size=40)
                self.lambda_adv = args.lambda_adv
                if args.rank <= 0:
                    logging.info('  using adv loss...')

            if args.loss_perceptual:
                self.criterion_perceptual = PerceptualLoss(layer_weights={
                    'conv5_4': 1.
                }).to(self.device)
                self.lambda_perceptual = args.lambda_perceptual
                if args.rank <= 0:
                    logging.info('  using perceptual loss...')

            self.optimizer_G = torch.optim.Adam(
                itertools.chain.from_iterable(g_params),
                lr=args.lr,
                weight_decay=args.weight_decay)
            self.scheduler = CosineAnnealingLR(
                self.optimizer_G, T_max=500)  # T_max=args.max_iter

            if args.resume_optim:
                self.load_networks('optimizer_G', self.args.resume_optim)
            if args.resume_scheduler:
                self.load_networks('scheduler', self.args.resume_scheduler)
Example #19
videos = np.array(list(set(videos) - set(test_videos)))
kfold = KFold(n_splits=5, shuffle=True, random_state=1)
splits = kfold.split(videos)
kfold_valid_acc = []
kfold_test_acc = []

# Keep the held-out test videos in their own dataframe
test_df = df[df['filename'].isin(test_videos)]
for (i, (train, test)) in enumerate(splits):
    print('%d-th split: train: %d, test: %d' % (i+1, len(videos[train]), len(videos[test])))
    train_df = df[df['filename'].isin(videos[train])]
    valid_df = df[df['filename'].isin(videos[test])]

    X_train, y_train = create_timeseries(train_df)
    X_valid, y_valid = create_timeseries(valid_df)  # validate on the k-fold split
    X_test, y_test = create_timeseries(test_df)
    train_dataloader = create_dataloader(X_train, y_train, batch_size)
    valid_dataloader = create_dataloader(X_valid, y_valid, batch_size)
    net = LSTM_ContemptNet()
    print(net)
    if torch.cuda.is_available():
        net.cuda()
    criterion = nn.CrossEntropyLoss()
    # TODO: if result are not good, change to RMSprop
    optimizer = optim.ASGD(net.parameters(), lr=learning_rate)

    train_losses = []
    valid_losses = []
    valid_acc = []
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
Example #20
loss_func = nn.CrossEntropyLoss().cuda()

opt = {
    'name': 'RSSCN7',
    'lmdb': True,
    'resample': False,
    'dataroot': '../../data/RSSCN7/',
    'mode': 'file',
    'batch_size': 64,
    "use_shuffle": True,
    "n_workers": 0,
    "num_classes": 7
}
train_set = create_dataset(opt, train=True)
train_loader = create_dataloader(train_set, opt)
# pre_optimizer = optim.SGD([{'params': [param for name, param in network.named_parameters() if 'fc' not in name]}, {'params': network.fc.parameters(), 'lr': 1}], lr=1, momentum=0.9)
#
# new_dict = pre_optimizer.state_dict()
# self_dict = optimizer.state_dict()
# self_dict['param_groups'][0].update(new_dict['param_groups'][0])
#
# # self_dict.update(new_dict)
# optimizer.load_state_dict(self_dict)
# optimizer_dict = weights_replace(optimizer.state_dict(), pre_optimizer.state_dict())

# schedulers = []
# schedulers.append(lr_scheduler.MultiStepLR(optimizer, [30, 80], 0.1))
multistep = [30, 80]
lambda1 = lambda step: 0.1**sum([step >= mst for mst in multistep])
arxiv_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
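The commented-out `MultiStepLR(optimizer, [30, 80], 0.1)` and the `LambdaLR` above encode the same schedule: the lambda counts how many milestones the current step has passed and decays the base rate by that power of 0.1. A standalone check of the multiplier (a dummy parameter and optimizer stand in for the network from the surrounding script):

import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1)

multistep = [30, 80]
lambda1 = lambda step: 0.1 ** sum(step >= mst for mst in multistep)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

for step in (0, 29, 30, 79, 80):
    print(step, lambda1(step))  # 1.0 through 29, 0.1 from 30, 0.01 from 80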
Example #21
def train():
    config = utils.get_config(args.config)
    out_dir = os.path.join(
        config['out_dir'], config['model'],
        time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
    os.makedirs(out_dir, exist_ok=True)
    torch.save({"config": config}, os.path.join(out_dir, "run_config.d"))
    log_file = os.path.join(out_dir, "logging.txt")
    logger = utils.gen_logger(log_file)

    logger.info("<==== Experiments for SSL ====>")
    logger.info("output directory: {}".format(out_dir))
    for k, v in config.items():
        logger.info("{}: {}".format(k, v))

    train_ids, dev_ids = utils.get_index(config['audio_h5'], debug=args.debug)

    process_fn = utils.process_fn(**config['audio_args'])
    trainloader = create_dataloader(config['audio_h5'],
                                    config['ref_h5'],
                                    process_fn,
                                    index=train_ids,
                                    **config['trainloader_args'])
    devloader = create_dataloader(config['audio_h5'],
                                  config['ref_h5'],
                                  process_fn,
                                  index=dev_ids,
                                  **config['devloader_args'])

    model = getattr(MODEL, config['model'])(**config['model_args'])
    model = model.to(device)

    optimizer = getattr(torch.optim,
                        config['optimizer'])(model.parameters(),
                                             **config['optimizer_args'])

    scheduler = getattr(torch.optim.lr_scheduler,
                        config['scheduler'])(optimizer,
                                             **config['scheduler_args'])

    criterion = getattr(losses,
                        config['criterion'])(**config['criterion_args'])

    def _train(trainer, batch):
        model.train()
        with torch.enable_grad():
            feats, ref_feats, indices = batch
            feats, ref_feats, indices = convert_tensor(feats, device),\
                convert_tensor(ref_feats, device), convert_tensor(indices, device)
            score, mask = model(feats, ref_feats, indices)
            loss = criterion(score, mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss.cpu().item()

    trainer = Engine(_train)

    def _evaluate(evaluator, batch):
        model.eval()
        with torch.no_grad():
            feats, ref_feats, indices = batch
            feats, ref_feats, indices = convert_tensor(feats, device),\
                convert_tensor(ref_feats, device), convert_tensor(indices, device)
            score, mask = model(feats, ref_feats, indices)
            loss = criterion(score, mask)
        return loss.cpu().item()

    evaluator = Engine(_evaluate)

    pbar = ProgressBar(ncols=75)
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    Average().attach(evaluator, 'Loss')
    Average().attach(trainer, 'Loss')

    @trainer.on(Events.STARTED)
    def eval_scratch(trainer):
        evaluator.run(devloader)
        eval_metric = evaluator.state.metrics['Loss']
        logger.info('Loss before training: {:<5.2f}'.format(eval_metric))

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(trainer):
        AvgLoss = trainer.state.metrics['Loss']
        n_epoch = trainer.state.epoch
        evaluator.run(devloader)
        eval_metric = evaluator.state.metrics['Loss']
        logger.info("<=== #{:<3} Epoch ===>".format(n_epoch))
        logger.info('Training Loss: {:<5.2f}'.format(AvgLoss))
        logger.info('Evaluation Loss: {:<5.2f}'.format(eval_metric))

        if model.ifhard:
            scheduler.step(eval_metric)

    earlystopping_handler = EarlyStopping(
        patience=config['patience'],
        trainer=trainer,
        score_function=lambda engine: -engine.state.metrics['Loss'])

    best_checkpoint_handler = ModelCheckpoint(
        dirname=out_dir,
        filename_prefix='eval_best',
        score_function=lambda engine: -engine.state.metrics['Loss'],
        score_name='loss',
        n_saved=1,
        global_step_transform=global_step_from_engine(trainer))

    periodic_checkpoint_handler = ModelCheckpoint(
        dirname=out_dir,
        filename_prefix='train_periodic',
        score_function=lambda engine: -engine.state.metrics['Loss'],
        score_name='loss',
        n_saved=None,
        global_step_transform=global_step_from_engine(trainer))

    @trainer.on(Events.EPOCH_COMPLETED(once=2))
    def add_handler(trainer):
        evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                    earlystopping_handler)
        evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                    best_checkpoint_handler, {"model": model})

    # evaluator.add_event_handler(
    #     Events.EPOCH_COMPLETED, earlystopping_handler)
    # evaluator.add_event_handler(
    #     Events.EPOCH_COMPLETED, best_checkpoint_handler, {"model": model})

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config['save_interval']),
        periodic_checkpoint_handler, {"model": model})

    trainer.run(trainloader, config['n_epochs'])