Example #1
def main():
    parser = argparse.ArgumentParser(
        description="easy2train for Change Detection")

    parser.add_argument(
        "-cfg",
        "--config_file",
        default="configs/h**o/default.yaml",
        metavar="FILE",
        help="Path to config file",
        type=str,
    )

    parser.add_argument(
        "-se",
        "--skip_eval",
        help="Do not eval the models(checkpoints)",
        action="store_true",
    )

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    run_train(cfg)
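
Every example on this page follows the same pattern: import a module-level `cfg`, merge a YAML file into it, merge trailing command-line overrides, then freeze it. The `cfg` object itself is never shown; in projects like these it is typically a yacs CfgNode defined in a separate config module. A minimal sketch of what such a module might look like, with illustrative option names only:

# config.py -- hedged sketch of the module-level `cfg` the examples import.
# Requires the yacs package; SOLVER.LR and SOLVER.BATCH_SIZE are illustrative
# names, not taken from any of the projects above.
from yacs.config import CfgNode as CN

_C = CN()
_C.SOLVER = CN()
_C.SOLVER.LR = 0.01          # default, overridable from YAML or the CLI
_C.SOLVER.BATCH_SIZE = 32    # default batch size

cfg = _C  # the training scripts then do: from config import cfg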
Example #2
File: struct2vec_exp.py, Project: czx94/d2l
def main():
    parser = argparse.ArgumentParser("Struct2vec training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = Struct2Vec(graph, cfg, logger)

    model.train()
    embedding = model.get_embedding()

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
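
Note how nargs=argparse.REMAINDER and merge_from_list work together: everything after the recognized flags is collected into args.opts as alternating KEY VALUE tokens. A small hypothetical demonstration, reusing the DATA.GRAPH_PATH key from this example:

# Hypothetical invocation:
#   python struct2vec_exp.py --config-file cfg.yaml DATA.GRAPH_PATH data/wiki_edges.txt
args = parser.parse_args(["--config-file", "cfg.yaml",
                          "DATA.GRAPH_PATH", "data/wiki_edges.txt"])
assert args.opts == ["DATA.GRAPH_PATH", "data/wiki_edges.txt"]
# cfg.merge_from_list(args.opts) then overrides cfg.DATA.GRAPH_PATH in place.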
Example #3
def main():

    parser = argparse.ArgumentParser(
        description="eval models from checkpoints.")

    parser.add_argument(
        "-cfg",
        "--config_file",
        default="configs/h**o/default.yaml",
        metavar="FILE",
        help="Path to config file",
        type=str,
    )

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    run_eval(cfg)
Example #4
def main():
    parser = argparse.ArgumentParser(description="PyTorch NDDR Training")
    parser.add_argument(
        "--config-file",
        default="configs/vgg16_nddr_pret.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]  # file name minus the ".yaml" extension
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # load the data
    test_loader = torch.utils.data.DataLoader(
        MultiTaskDataset(
            data_dir=cfg.DATA_DIR,
            data_list_1=cfg.TEST.DATA_LIST_1,
            data_list_2=cfg.TEST.DATA_LIST_2,
            output_size=cfg.TEST.OUTPUT_SIZE,
            random_scale=cfg.TEST.RANDOM_SCALE,
            random_mirror=cfg.TEST.RANDOM_MIRROR,
            random_crop=cfg.TEST.RANDOM_CROP,
            ignore_label=cfg.IGNORE_LABEL,
        ),
        batch_size=cfg.TEST.BATCH_SIZE, shuffle=False)

    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights='')
    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights='')
    model = NDDRNet(net1, net2,
                    shortcut=cfg.MODEL.SHORTCUT,
                    bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)
    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME, 'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    if cfg.CUDA:
        model = model.cuda()

    mIoU, pixel_acc, angle_metrics = evaluate(test_loader, model)
    print('Mean IoU: {:.3f}'.format(mIoU))
    print('Pixel Acc: {:.3f}'.format(pixel_acc))
    for k, v in angle_metrics.items():
        print('{}: {:.3f}'.format(k, v))
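
Example #4 is the evaluation counterpart of the training loop in Example #8 below: both scripts build the checkpoint filename by zero-padding the step or checkpoint ID to five digits, so training and evaluation agree on paths:

>>> 'ckpt-%s.pth' % str(1500).zfill(5)
'ckpt-01500.pth'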
Example #5
def main(device):
    parser = argparse.ArgumentParser("SDNE training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = SDNE(graph, cfg).to(device)

    L_matrix, adj_matrix = model.get_matrix()
    criterion = MixLoss(cfg, L_matrix, adj_matrix, device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.WORD2VEC.LR,
                                 weight_decay=cfg.SDNE.L2)

    train(adj_matrix, model, cfg, criterion, optimizer, device)
    embedding = model.get_embedding(adj_matrix.to(device))

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
Example #6
def main(device):
    parser = argparse.ArgumentParser("LINE training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger, log_path = create_logger(cfg)
    logger.info(cfg)

    graph = networkx.read_edgelist(cfg.DATA.GRAPH_PATH,
                                   create_using=networkx.DiGraph(),
                                   nodetype=None,
                                   data=[('weight', int)])
    model = LINE(graph, cfg).to(device)
    v2ind, ind2v = model.get_mapping()
    sampler = Sampler(graph, v2ind, batch_size=cfg.SAMPLE.BATCHSIZE)

    criterion = KLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.WORD2VEC.LR)

    train(sampler, model, cfg, criterion, optimizer, device)
    embedding = model.get_embedding()

    eval_embedding(embedding, cfg.DATA.LABEL_PATH, logger)
    vis_embedding(embedding, cfg.DATA.LABEL_PATH, log_path)
Example #7
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument(
        "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # build model, optimizer and scheduler
    model = make_model(cfg)
    model = model.to(cfg.DEVICE)
    optimizer = build_optimizer(cfg, model)
    print('optimizer built!')
    # NOTE: add separate optimizers to train single object predictor and interaction predictor

    if cfg.USE_WANDB:
        logger = Logger("FOL", cfg, project=cfg.PROJECT, viz_backend="wandb")
    else:
        logger = logging.Logger("FOL")

    dataloader_params = {
        "batch_size": cfg.SOLVER.BATCH_SIZE,
        "shuffle": True,
        "num_workers": cfg.DATALOADER.NUM_WORKERS
    }

    # get dataloaders
    train_dataloader = make_dataloader(cfg, 'train')
    val_dataloader = make_dataloader(cfg, 'val')
    test_dataloader = make_dataloader(cfg, 'test')
    print('Dataloader built!')
    # get train_val_test engines
    do_train, do_val, inference = build_engine(cfg)
    print('Training engine built!')
    if hasattr(logger, 'run_id'):
        run_id = logger.run_id
    else:
        run_id = 'no_wandb'

    save_checkpoint_dir = os.path.join(cfg.CKPT_DIR, run_id)
    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)

    # NOTE: hyperparameter scheduler
    model.param_scheduler = ParamScheduler()
    model.param_scheduler.create_new_scheduler(name='kld_weight',
                                               annealer=sigmoid_anneal,
                                               annealer_kws={
                                                   'device': cfg.DEVICE,
                                                   'start': 0,
                                                   'finish': 100.0,
                                                   'center_step': 400.0,
                                                   'steps_lo_to_hi': 100.0,
                                               })

    model.param_scheduler.create_new_scheduler(name='z_logit_clip',
                                               annealer=sigmoid_anneal,
                                               annealer_kws={
                                                   'device': cfg.DEVICE,
                                                   'start': 0.05,
                                                   'finish': 5.0,
                                                   'center_step': 300.0,
                                                   'steps_lo_to_hi': 300.0 / 5.
                                               })

    if cfg.SOLVER.scheduler == 'exp':
        # exponential schedule
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=cfg.SOLVER.GAMMA)
    elif cfg.SOLVER.scheduler == 'plateau':
        # Plateau scheduler
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            factor=0.2,
                                                            patience=5,
                                                            min_lr=1e-07,
                                                            verbose=1)
    else:
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=[25, 40],
                                                      gamma=0.2)

    print('Schedulers built!')

    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        logger.info("Epoch:{}".format(epoch))
        do_train(cfg,
                 epoch,
                 model,
                 optimizer,
                 train_dataloader,
                 cfg.DEVICE,
                 logger=logger,
                 lr_scheduler=lr_scheduler)
        val_loss = do_val(cfg,
                          epoch,
                          model,
                          val_dataloader,
                          cfg.DEVICE,
                          logger=logger)
        if (epoch + 1) % 1 == 0:  # modulus of 1: run inference after every epoch
            inference(cfg,
                      epoch,
                      model,
                      test_dataloader,
                      cfg.DEVICE,
                      logger=logger,
                      eval_kde_nll=False)

        torch.save(
            model.state_dict(),
            os.path.join(save_checkpoint_dir,
                         'Epoch_{}.pth'.format(str(epoch).zfill(3))))

        # update LR; only ReduceLROnPlateau consumes the validation loss
        if cfg.SOLVER.scheduler == 'plateau':
            lr_scheduler.step(val_loss)
        elif cfg.SOLVER.scheduler != 'exp':
            lr_scheduler.step()
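
sigmoid_anneal comes from the project itself, but its keyword arguments (start, finish, center_step, steps_lo_to_hi) point to a standard sigmoid annealing schedule. A plausible sketch of such an annealer, to be read as an assumption rather than the project's actual code:

import torch

def sigmoid_anneal(device, start, finish, center_step, steps_lo_to_hi, step):
    # Moves a hyperparameter from `start` to `finish` along a sigmoid that
    # crosses its midpoint at `center_step`; `steps_lo_to_hi` controls how
    # many steps the transition takes. Assumed semantics, not the real code.
    t = torch.tensor(float(step), dtype=torch.float32, device=device)
    return start + (finish - start) * torch.sigmoid(
        (t - center_step) / steps_lo_to_hi)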
Example #8
def main():
    parser = argparse.ArgumentParser(description="PyTorch NDDR Training")
    parser.add_argument(
        "--config-file",
        default="configs/vgg16_nddr_pret.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]  # file name minus the ".yaml" extension
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if not os.path.exists(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME)):
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_loader = torch.utils.data.DataLoader(
        MultiTaskDataset(
            data_dir=cfg.DATA_DIR,
            data_list_1=cfg.TRAIN.DATA_LIST_1,
            data_list_2=cfg.TRAIN.DATA_LIST_2,
            output_size=cfg.TRAIN.OUTPUT_SIZE,
            random_scale=cfg.TRAIN.RANDOM_SCALE,
            random_mirror=cfg.TRAIN.RANDOM_MIRROR,
            random_crop=cfg.TRAIN.RANDOM_CROP,
            ignore_label=cfg.IGNORE_LABEL,
        ),
        batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True)

    # load the data
    if cfg.TRAIN.EVAL_CKPT:
        test_loader = torch.utils.data.DataLoader(
            MultiTaskDataset(
                data_dir=cfg.DATA_DIR,
                data_list_1=cfg.TEST.DATA_LIST_1,
                data_list_2=cfg.TEST.DATA_LIST_2,
                output_size=cfg.TEST.OUTPUT_SIZE,
                random_scale=cfg.TEST.RANDOM_SCALE,
                random_mirror=cfg.TEST.RANDOM_MIRROR,
                random_crop=cfg.TEST.RANDOM_CROP,
                ignore_label=cfg.IGNORE_LABEL,
            ),
            batch_size=cfg.TEST.BATCH_SIZE, shuffle=False)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME, timestamp)
    if not os.path.exists(experiment_log_dir):
        os.makedirs(experiment_log_dir)
    writer = SummaryWriter(logdir=experiment_log_dir)

    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights=cfg.TRAIN.WEIGHT_1)
    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights=cfg.TRAIN.WEIGHT_2)
    model = NDDRNet(net1, net2,
                    shortcut=cfg.MODEL.SHORTCUT,
                    bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)
    
    if cfg.CUDA:
        model = model.cuda()
    model.train()
    steps = 0

    seg_loss = nn.CrossEntropyLoss(ignore_index=255)

    # hacky way to pick params
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_parameters():
        if 'nddrs' in k:
            nddr_params.append(v)
        elif cfg.MODEL.FC8_ID in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            base_params.append(v)
    assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

    parameter_dict = [
        {'params': base_params},
        {'params': fc8_weights, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR},
        {'params': fc8_bias, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR},
        {'params': nddr_params, 'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR}
    ]
    optimizer = optim.SGD(parameter_dict, lr=cfg.TRAIN.LR, momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    
    if cfg.TRAIN.SCHEDULE == 'Poly':
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER, last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.TRAIN.STEPS)
    else:
        raise NotImplementedError

    while steps < cfg.TRAIN.STEPS:
        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(), label_2.cuda()
            optimizer.zero_grad()
            out1, out2 = model(image)

            # loss_seg = get_seg_loss(out1, label_1, 40, 255)
            loss_seg = seg_loss(out1, label_1.squeeze(1))
            loss_normal = get_normal_loss(out2, label_2, 255)

            loss = loss_seg + cfg.TRAIN.NORMAL_FACTOR * loss_normal

            loss.backward()

            optimizer.step()
            scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0:
                print('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'.format(
                    steps, batch_idx * len(image), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data.item(),
                    loss_seg.data.item(), loss_normal.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                writer.add_scalar('loss/seg', loss_seg.data.item(), steps)
                writer.add_scalar('loss/normal', loss_normal.data.item(), steps)

                writer.add_image('image', process_image(image[0]), steps)
                seg_pred, seg_gt = process_seg_label(
                    out1.argmax(dim=1)[0].detach(),
                    label_1.squeeze(1)[0],
                    cfg.MODEL.NET1_CLASSES
                )
                writer.add_image('seg/pred', seg_pred, steps)
                writer.add_image('seg/gt', seg_gt, steps)
                normal_pred, normal_gt = process_normal_label(out2[0].detach(), label_2[0], 255)
                writer.add_image('normal/pred', normal_pred, steps)
                writer.add_image('normal/gt', normal_gt, steps)

            if steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss_seg': loss_seg,
                    'loss_normal': loss_normal,
                    'mIoU': None,
                    'PixelAcc': None,
                    'angle_metrics': None,
                }

                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    mIoU, pixel_acc, angle_metrics = evaluate(test_loader, model)
                    writer.add_scalar('eval/mIoU', mIoU, steps)
                    writer.add_scalar('eval/PixelAcc', pixel_acc, steps)
                    for k, v in angle_metrics.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    checkpoint['mIoU'] = mIoU
                    checkpoint['PixelAcc'] = pixel_acc
                    checkpoint['angle_metrics'] = angle_metrics
                    model.train()

                torch.save(checkpoint, os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                                    'ckpt-%s.pth' % str(steps).zfill(5)))

            steps += 1
            if steps >= cfg.TRAIN.STEPS:
                break
        description="PyTorch Object Detection Training")
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument(
        "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    main(cfg)

    # NOTE: define all the ckpt we want to check
    # all_ckpts = ['data/ETH_UCY_trajectron/checkpoints/goal_cvae_checkpoints/yh58kusd/Epoch_{}.pth'.format(str(i).zfill(3)) for i in range(1, 16)]
    # for ckpt in all_ckpts:
    #     cfg.CKPT_DIR = ckpt
    #     for min_hist_len in [1, 8]:
    #         cfg.MODEL.MIN_HIST_LEN = min_hist_len
    #         main(cfg)
Example #10
def main():
    # Read config settings from yaml file
    args = get_parser().parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # create logger
    output_dir = pathlib.Path(cfg.OUTPUT_DIR)
    output_dir.mkdir(exist_ok=True, parents=True)
    logger = setup_logger("U", "logs")
    logger.info("Running with configuration:\n{}".format(cfg))

    lr = cfg.SOLVER.LEARN_RATE
    bs = cfg.SOLVER.BATCH_SIZE
    test_bs = cfg.SOLVER.TEST_BATCH_SIZE
    epochs = cfg.TRAINING.EPOCHS

    # sets the matplotlib display backend (most likely not needed)
    # mp.use('TkAgg', force=True)

    data = create_dataset()
    tte_test_dataset, tee_test_dataset = load_test_data()

    # compute the train/val partition sizes
    print("Train Data length: {}".format(len(data)))
    train_partition = int(cfg.TRAINING.TRAIN_PARTITION * len(data))
    val_partition = len(data) - train_partition

    # split the training dataset and initialize the data loaders
    train_dataset, valid_dataset = torch.utils.data.random_split(
        data, (train_partition, val_partition))
    train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    valid_data = DataLoader(valid_dataset, batch_size=bs, shuffle=True)
    tte_test_data = DataLoader(tte_test_dataset,
                               batch_size=test_bs,
                               shuffle=True)
    tee_test_data = DataLoader(tee_test_dataset,
                               batch_size=test_bs,
                               shuffle=True)

    if cfg.TRAINING.VISUAL_DEBUG:
        plotter.plot_image_and_mask(data, 1)
    xb, yb = next(iter(train_data))
    print(xb.shape, yb.shape)

    # build the Unet2D with default input and output channels as given by config

    unet = Unet2D()
    # logger.info(unet)

    # loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(unet.parameters(), lr=lr)

    # either test only, or train and then evaluate
    if cfg.TESTING.TEST_ONLY:
        run_test(unet, opt, tee_test_data)
        evauate_and_log_results(logger, unet, tte_test_data, tee_test_data)
        return
    else:
        train_loss, valid_loss, valid_dice, valid_dice_per_class, pixel_acc = start_train(
            unet,
            train_data,
            valid_data,
            loss_fn,
            opt,
            dice_multiclass,
            acc_metric,
            epochs=epochs)

    # Evaluate the network
    evauate_and_log_results(logger, unet, tte_test_data, tee_test_data)

    if cfg.TRAINING.VISUAL_DEBUG:
        # plot training and validation losses
        plotter.plot_train_and_val_loss(train_loss, valid_loss)

        # plot validation pixel accuracy
        plotter.plot_avg_dice_history(valid_dice)

        # plot validation dice
        plotter.plot_dice_history(valid_dice_per_class)

        # show the predicted segmentations
        plotter.predict_on_batch_and_plot(tte_test_data, unet, test_bs)
        plotter.predict_on_batch_and_plot(tee_test_data, unet, test_bs)
Example #11
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Transfer Learning Task")
    parser.add_argument("--config",
                        default="",
                        metavar="FILE",
                        help="path to config file",
                        type=str)
    parser.add_argument("--load_ckpt",
                        default=None,
                        metavar="FILE",
                        help="path to the pretrained model to load ",
                        type=str)
    parser.add_argument(
        '--test_only',
        action='store_true',
        help="test the model (need to provide model path to --load_ckpt)")
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if args.test_only:
        assert os.path.exists(
            args.load_ckpt
        ), "You need to provide the valid model path using --load_ckpt"

    cfg.merge_from_file(args.config)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    np.random.seed(1234)
    torch.manual_seed(1234)
    if 'cuda' in device:
        torch.cuda.manual_seed_all(1234)

    if not os.path.exists(cfg.OUTPUT_DIR):
        print('creating output dir: {}'.format(cfg.OUTPUT_DIR))
        os.makedirs(cfg.OUTPUT_DIR)

    logger = setup_logger("fewshot", cfg.OUTPUT_DIR)
    logger.info("Loaded configuration file {}".format(args.config))
    with open(args.config, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # save overloaded model config in the output directory
    save_config(cfg, output_config_path)

    model = architectures[cfg.MODEL.ARCHITECTURE](cfg, logger)
    model.to(device)

    meta_test_dataset = FashionProductImagesFewShot(cfg=cfg,
                                                    split='test',
                                                    k_way=cfg.TEST.K_WAY,
                                                    n_shot=cfg.TEST.N_SHOT,
                                                    logger=logger)
    meta_test_dataset.params['shuffle'] = False
    meta_test_dataset.params['batch_size'] = 1
    meta_test_dataloader = DataLoader(meta_test_dataset,
                                      **meta_test_dataset.params)

    if args.test_only:
        do_test(cfg, model, meta_test_dataloader, logger, 'test',
                args.load_ckpt)
    else:
        meta_train_dataset = FashionProductImagesFewShot(
            cfg=cfg,
            split='train',
            k_way=cfg.TRAIN.K_WAY,
            n_shot=cfg.TRAIN.N_SHOT,
            logger=logger)
        meta_train_dataloader = DataLoader(meta_train_dataset,
                                           **meta_train_dataset.params)

        do_train(cfg, model, meta_train_dataloader, logger, args.load_ckpt)
        do_test(cfg, model, meta_test_dataloader, logger, 'test', None)