def load_dataset_inputs(prefix):
    # PNGPretrainDataset below expects two HDF5 paths per split; the second
    # filename here is an assumption and should match your preprocessing output
    input_f = os.path.join(args.input_dir, '{}_inputs.h5'.format(prefix))
    graph_f = os.path.join(args.input_dir, '{}_graphs.h5'.format(prefix))
    return input_f, graph_f


def main():
    config = json.load(open(args.config_file))
    assert config['config_target'] == 'naive_psychology'
    set_seed(
        args.gpu_id, args.seed
    )  # in distributed training, this has to be same for all processes

    if args.mode == 'train':
        train_dinputs = load_dataset_inputs('train')
        dev_dinputs = load_dataset_inputs('dev')
        # train() expects (rank, train_inputs, dev_inputs, config, args);
        # this launch assumes `import torch.multiprocessing as mp` at the top
        # of the file, and that MASTER_ADDR/MASTER_PORT are set in the
        # environment for init_method='env://' in train()
        if args.n_gpus > 1:
            mp.spawn(train,
                     args=(train_dinputs, dev_dinputs, config, args),
                     nprocs=args.n_gpus)
        else:
            train(0, train_dinputs, dev_dinputs, config, args)

        # test
        test_dinputs = load_dataset_inputs('test')
        test(test_dinputs, args.output_dir)
    else:  # test
        test_dinputs = load_dataset_inputs('test')
        test(test_dinputs, args.from_checkpoint)


def train(rank, train_inputs, dev_inputs, config, args):
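    # one worker process per GPU: `rank` is this process's index when launched
    # via mp.spawn; a single-GPU run calls train(0, ...) directly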
    logger = utils.get_root_logger(args, log_fname='log_rank{}'.format(rank))
    if args.n_gpus > 1:
        local_rank = rank
        args.gpu_id = rank
    else:
        local_rank = -1
    if args.n_gpus > 1:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.n_gpus,
                                rank=local_rank)
    set_seed(
        args.gpu_id, args.seed
    )  # in distributed training, this has to be same for all processes

    logger.info('local_rank = {}, n_gpus = {}'.format(local_rank, args.n_gpus))
    logger.info('n_epochs = {}'.format(args.n_epochs))
    if args.gpu_id != -1:
        torch.cuda.set_device(args.gpu_id)

    # distribution over relation types, remapped from type names to indices;
    # used when sampling truncated graphs / negative edges for each batch
    rtype_distr = config['rtype_distr']
    rtype2idx = config['rtype2idx']
    rtype_distr = {rtype2idx[r]: d for r, d in rtype_distr.items()}

    # dev dataset
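    # (built only on the main process, which is the only one that runs evaluation)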
    if local_rank in [-1, 0]:
        dev_dataset = PNGPretrainDataset(dev_inputs[0],
                                         dev_inputs[1],
                                         args.task,
                                         is_train=False)
        dev_dataloader = DataLoader(
            dev_dataset,
            batch_size=args.eval_batch_size,  # fix 1 for sampling
            shuffle=False,
            collate_fn=my_dev_collate,
            num_workers=1)  # 1 is safe for hdf5

    # train dataset
    train_dataset = PNGPretrainDataset(train_inputs[0],
                                       train_inputs[1],
                                       args.task,
                                       is_train=True)
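    # with multiple GPUs, DistributedSampler shards the data across processes
    # and does its own shuffling, so the DataLoader's shuffle must stay False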
    if args.n_gpus > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=args.n_gpus,
                                           rank=local_rank)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,  # fix 1 for sampling
        shuffle=shuffle,
        collate_fn=my_train_collate,
        num_workers=1,
        sampler=train_sampler)  # 1 is safe for hdf5

    # model
    # pos_weight = torch.FloatTensor([args.pos_weight])
    model = get_model(args.from_checkpoint, args.weight_name, args.freeze_lm)
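    # keep a handle on the model's loss before any DDP wrapping
    # (afterwards it would live under model.module.criterion)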
    criterion = model.criterion
    if args.gpu_id != -1:
        model = model.cuda(args.gpu_id)
        # pos_weight = pos_weight.cuda(args.gpu_id)
    if args.n_gpus > 1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.gpu_id],
                                        find_unused_parameters=True)

    optimizer = utils.get_optimizer_adam(model, args.weight_decay, args.lr,
                                         args.adam_epsilon,
                                         args.from_checkpoint,
                                         (args.adam_beta1, args.adam_beta2))
    # optimizer = utils.get_optimizer_adamw(
    #     model, args.weight_decay, args.lr,
    #     args.adam_epsilon, args.from_checkpoint
    # )
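    # the LR schedule is sized from the dataset size, train batch size,
    # gradient accumulation steps, epoch count, and the warmup settings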
    n = len(train_dataset)
    scheduler = utils.get_scheduler(n, optimizer, args.train_batch_size,
                                    args.gradient_accumulation_steps,
                                    args.n_epochs, args.warmup_steps,
                                    args.warmup_portion, args.from_checkpoint)

    if local_rank in [-1, 0]:
        logger.info("***** Running training *****")
        logger.info("  Num Epochs = %d", args.n_epochs)
        logger.info("  Training batch size = %d", args.train_batch_size)
        logger.info("  Evaluation batch size = %d", args.eval_batch_size)
        logger.info("  Accu. train batch size = %d",
                    args.train_batch_size * args.gradient_accumulation_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Weight Decay = {}".format(args.weight_decay))
        logger.info("  Learning Rate = {}".format(args.lr))

        if args.no_first_eval or args.no_eval:
            best_metric = 0.0
        else:
            best_metric = evaluate(model,
                                   dev_dataloader,
                                   args.gpu_id,
                                   logger=logger)
        logger.info('start dev_metric = {}'.format(best_metric))
        tb_writer = SummaryWriter('{}/explog'.format(args.output_dir))
    else:
        best_metric = 0.0

    t1 = time.time()
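    # accu_step counts mini-batches; step counts optimizer updates
    # (one update per gradient_accumulation_steps mini-batches)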
    step, accu_step = 0, 0
    prev_acc_loss, acc_loss = 0.0, 0.0
    model.zero_grad()
    for i_epoch in range(args.n_epochs):
        t2 = time.time()
        logger.info('========== Epoch {} =========='.format(i_epoch))

        for batch in train_dataloader:
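            # sample truncated graphs and negative edges for this batch,
            # guided by the relation-type distribution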
            batch = batch_sample_truncated_graphs(batch, rtype_distr, args)
            if batch is None:
                # negative-edge sampling can occasionally fail
                logger.warning('unable to sample negative edges; skipping this batch')
                continue

            batch = calculate_norms(batch)

            model.train()
            # to GPU
            if args.gpu_id != -1:
                batch = to_gpu(batch, args.gpu_id)

            # forward pass
            all_scores, all_ys, all_rtypes, all_embs = model(**batch)

            loss = criterion(all_scores, all_ys, all_rtypes, all_embs)
            if args.n_gpus > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # backward pass
            loss.backward()

            # accumulation
            acc_loss += loss.item()
            accu_step += 1
            if accu_step % args.gradient_accumulation_steps == 0:  # a final partial accumulation never triggers an update
                # update params
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                step += 1

                # periodic logging, evaluation, and checkpointing
                if local_rank in [-1, 0]:
                    if args.logging_steps > 0 and step % args.logging_steps == 0:
                        cur_loss = (acc_loss -
                                    prev_acc_loss) / args.logging_steps
                        logger.info(
                            'task={}, train_loss={}, accu_step={}, step={}, time={}s'
                            .format(args.task, cur_loss, accu_step, step,
                                    time.time() - t1))
                        tb_writer.add_scalar('train_loss', cur_loss, step)
                        tb_writer.add_scalar('lr', scheduler.get_last_lr()[0],
                                             step)

                        # evaluate
                        if not args.no_eval:
                            dev_metric = evaluate(model,
                                                  dev_dataloader,
                                                  args.gpu_id,
                                                  logger=logger)
                            logger.info('dev_metric={}'.format(dev_metric))
                            if best_metric < dev_metric:
                                best_metric = dev_metric

                                # save
                                utils.save_model(model, optimizer, scheduler,
                                                 args.output_dir, step)
                        prev_acc_loss = acc_loss

        logger.info('done epoch {}: {} s'.format(i_epoch, time.time() - t2))
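        # with evaluation disabled there is no best-metric checkpoint,
        # so save at the end of every epoch instead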
        if local_rank in [-1, 0] and args.no_eval:
            logger.info('saving model for epoch {}'.format(i_epoch))
            utils.save_model(model, optimizer, scheduler, args.output_dir,
                             step)
    if local_rank in [-1, 0]:
        tb_writer.close()
    logger.info('best_dev_metric = {}'.format(best_metric))
    logger.info('done training: {} s'.format(time.time() - t1))


if __name__ == "__main__":
    args = utils.bin_config(get_arguments)
    logger = utils.get_root_logger(args)
    main()