Example #1
0
    def __init__(self, opt=None):
        """Load the pretrained baseline model and put it in eval mode.

        Args:
            opt: optional options object. When provided it must expose
                ``project_root``; it may also expose ``use_refine_baseline``
                to select the "refined" checkpoint directory.
        """
        self.opt = opt
        # Resolve the evaluation root: a hard-coded path when run as a
        # script, otherwise derived from the caller-supplied project root.
        if __name__ == '__main__':
            self.root = '/scr1/system/gamma-robot/scripts/Eval/base_eval'
        else:
            self.root = os.path.join(self.opt.project_root, 'scripts', 'Eval',
                                     'base_eval')

        self.config = load_json_config(
            os.path.join(self.root,
                         "configs/pretrained/config_model1_left_right.json"))

        # setup device - GPU only (CUDA is assumed to be available)
        self.device = torch.device("cuda")
        self.device_ids = [0]

        # create model and wrap it for (multi-)GPU execution
        self.model = MultiColumn(self.config["num_classes"], Model,
                                 int(self.config["column_units"]))
        self.model = torch.nn.DataParallel(self.model,
                                           self.device_ids).to(self.device)

        # Pick the checkpoint directory. ``opt`` may be None or may lack the
        # ``use_refine_baseline`` flag, which raises AttributeError -- fall
        # back to the model named in the config in that case. (The original
        # bare ``except:`` also swallowed SystemExit/KeyboardInterrupt.)
        try:
            use_refined = self.opt.use_refine_baseline
        except AttributeError:
            use_refined = False
        subdir = "refined" if use_refined else self.config['model_name']
        save_dir = os.path.join(self.root, "trained_models/pretrained", subdir)
        checkpoint_path = os.path.join(save_dir, 'model_best.pth.tar')

        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
Example #2
0
def get_reward(filepath):
    """Run the pretrained baseline model on one video and rank the classes.

    Args:
        filepath: directory of frame images readable by ``videoloader``.

    Returns:
        Tuple ``(output, output_index)`` where ``output`` is the softmax
        class-probability vector (1-D numpy array) and ``output_index``
        holds the class indices sorted by descending probability.
    """
    # create model and wrap it for (multi-)GPU execution
    model = MultiColumn(config["num_classes"], cnn_def.Model,
                        int(config["column_units"]))
    model = torch.nn.DataParallel(model, device_ids).to(device)

    save_dir = os.path.join(
        "/scr1/system/beta-robot/smth-smth-v2-baseline-with-models/trained_models/pretrained/"
        + config['model_name'])
    checkpoint_path = os.path.join(save_dir, 'model_best.pth.tar')

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    with torch.no_grad():
        input = videoloader(filepath)
        input = input.float().unsqueeze(0)
        input_var = [input.to(device)]
        output = model(input_var, False)
        output = F.softmax(output, 1)
        output = output.cpu().detach().numpy()
        output = np.squeeze(output)

        # negate so argsort yields indices in descending-probability order
        output_index = np.argsort(output * -1.0)

        return (output, output_index)
Example #3
0
    elif config["input_mode"] == "uiuc":
        from data_loader_uiuc import VideoFolder
    else:
        raise ValueError("Please provide a valid input mode")

    # set run output folder
    model_name = config["model_name"]
    output_dir = config["output_dir"]
    save_dir = os.path.join(output_dir, model_name)

    # assign Ctrl+C signal handler
    signal.signal(signal.SIGINT, ExperimentalRunCleaner(save_dir))

    # create model
    print(" > Creating model ... !")
    model = MultiColumn(config['num_classes'], cnn_def.Model,
                        int(config["column_units"]))

    # multi GPU setting
    model = torch.nn.DataParallel(model, device_ids).to(device)

    # optionally resume from a checkpoint
    checkpoint_path = os.path.join(config['output_dir'], config['model_name'],
                                   'model_best.pth.tar')

    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['state_dict'])
        print(" > Loaded checkpoint '{}' (epoch {})".format(
            checkpoint_path, checkpoint['epoch']))
Example #4
0
def main():
    """Train the MultiColumn video model described by the global ``config``.

    Builds the model, optionally resumes or finetunes from a checkpoint,
    constructs the data pipelines, then trains until ``num_epochs`` is
    reached or the learning rate decays below ``config['last_lr']``,
    checkpointing on best validation loss.
    """
    global args, best_loss

    # set run output folder
    model_name = config["model_name"]
    output_dir = config["output_dir"]
    save_dir = os.path.join(output_dir, model_name)
    print(" > Output folder for this run -- {}".format(save_dir))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    # assign Ctrl+C signal handler so aborted runs clean up their folder
    signal.signal(signal.SIGINT, ExperimentalRunCleaner(save_dir))

    # create model
    print(" > Creating model ... !")
    model = MultiColumn(config['num_classes'], cnn_def.Model,
                        int(config["column_units"]))

    # multi GPU setting
    model = torch.nn.DataParallel(model, device_ids).to(device)

    # define optimizer
    lr = config["lr"]
    last_lr = config["last_lr"]
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    lr_decayer = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            'min',
                                                            factor=0.5,
                                                            patience=2,
                                                            verbose=True)

    # optionally resume from a checkpoint
    checkpoint_path = os.path.join(config['output_dir'], config['model_name'],
                                   'model_best.pth.tar')
    if args.resume:
        if os.path.isfile(checkpoint_path):
            print(" > Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(checkpoint_path)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            lr_decayer.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(" > Loaded checkpoint '{}' (epoch {})".format(
                checkpoint_path, checkpoint['epoch']))
        else:
            print(" !#! No checkpoint found at '{}'".format(checkpoint_path))
    elif config.get('finetune_from') is not None:
        print(' > Loading checkpoint to finetune')
        finetune_model_name = config['finetune_from']
        checkpoint_path = os.path.join(config['output_dir'],
                                       finetune_model_name,
                                       'model_best.pth.tar')
        checkpoint = torch.load(checkpoint_path)
        # The finetuning source has a 174-way classifier head; install a
        # matching head so the state dict loads, then swap in a fresh head
        # sized for this task's classes.
        model.module.clf_layers = nn.Sequential(
            nn.Linear(model.module.column_units, 174)).to(device)
        model.load_state_dict(checkpoint['state_dict'])
        model.module.clf_layers = nn.Sequential(
            nn.Linear(model.module.column_units,
                      config['num_classes'])).to(device)
        print(" > Loaded checkpoint '{}' (epoch {}))".format(
            checkpoint_path, checkpoint['epoch']))
        # Freeze first 3 blocks
        for param in model.module.conv_column.block1.parameters():
            param.requires_grad = False
        for param in model.module.conv_column.block2.parameters():
            param.requires_grad = False
        for param in model.module.conv_column.block3.parameters():
            param.requires_grad = False

    # define augmentation pipeline
    upscale_size_train = int(config['input_spatial_size'] *
                             config["upscale_factor_train"])
    upscale_size_eval = int(config['input_spatial_size'] *
                            config["upscale_factor_eval"])

    # Random crop videos during training
    transform_train_pre = ComposeMix([
        [RandomRotationVideo(15), "vid"],
        [Scale(upscale_size_train), "img"],
        [RandomCropVideo(config['input_spatial_size']), "vid"],
    ])

    # Center crop videos during evaluation
    transform_eval_pre = ComposeMix([
        [Scale(upscale_size_eval), "img"],
        [torchvision.transforms.ToPILImage(), "img"],
        [
            torchvision.transforms.CenterCrop(config['input_spatial_size']),
            "img"
        ],
    ])

    # Transforms common to train and eval sets and applied after "pre" transforms
    transform_post = ComposeMix([
        [torchvision.transforms.ToTensor(), "img"],
        [
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # default values for imagenet
                std=[0.229, 0.224, 0.225]),
            "img"
        ]
    ])

    train_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_train'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_train'],
        step_size=config['step_size_train'],
        is_val=False,
        transform_pre=transform_train_pre,
        transform_post=transform_post,
        augmentation_mappings_json=config['augmentation_mappings_json'],
        augmentation_types_todo=config['augmentation_types_todo'],
        get_item_id=False,
    )

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True)

    val_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_val'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_val'],
        step_size=config['step_size_val'],
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
    )

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    test_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_test'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_val'],
        step_size=config['step_size_val'],
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
        is_test=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=False)

    print(" > Number of dataset classes : {}".format(len(train_data.classes)))
    assert len(train_data.classes) == config["num_classes"]

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device)

    if args.eval_only:
        validate(val_loader, model, criterion, train_data.classes_dict)
        print(" > Evaluation DONE !")
        return

    # set callbacks
    plotter = PlotLearning(os.path.join(save_dir, "plots"),
                           config["num_classes"])
    val_loss = float('Inf')

    # set end condition by num epochs (-1 means "train until LR converges")
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    print(" > Training is getting started...")
    print(" > Training takes {} epochs.".format(num_epochs))
    start_epoch = args.start_epoch if args.resume else 0

    for epoch in range(start_epoch, num_epochs):

        lrs = [params['lr'] for params in optimizer.param_groups]
        print(" > Current LR(s) -- {}".format(lrs))
        # BUGFIX: compare the *current* per-group LRs (``lrs``), not the
        # initial config scalar ``lr`` -- with ``lr`` this stop condition
        # could never fire after the scheduler decayed the learning rate.
        # NOTE(review): exits with status 1 on normal LR convergence --
        # looks unintentional for a success path, kept for compatibility.
        if np.max(lrs) < last_lr and last_lr > 0:
            print(" > Training is DONE by learning rate {}".format(last_lr))
            sys.exit(1)

        with experiment.train():
            # train for one epoch
            train_loss, train_top1, train_top5 = train(train_loader, model,
                                                       criterion, optimizer,
                                                       epoch)
            metrics = {
                'avg_loss': train_loss,
                'avg_top1': train_top1,
                'avg_top5': train_top5,
            }
            experiment.log_metrics(metrics)

        with experiment.validate():
            # evaluate on validation set
            val_loss, val_top1, val_top5 = validate(val_loader, model,
                                                    criterion)
            metrics = {
                'avg_loss': val_loss,
                'avg_top1': val_top1,
                'avg_top5': val_top5,
            }
            experiment.log_metrics(metrics)
        experiment.log_metric('epoch', epoch)

        # set learning rate
        lr_decayer.step(val_loss, epoch)

        # plot learning
        plotter_dict = {}
        plotter_dict['loss'] = train_loss
        plotter_dict['val_loss'] = val_loss
        plotter_dict['acc'] = train_top1 / 100
        plotter_dict['val_acc'] = val_top1 / 100
        # BUGFIX: plot the current LR, not the never-updated initial ``lr``
        plotter_dict['learning_rate'] = lrs[0]
        plotter.plot(plotter_dict)

        print(" > Validation loss after epoch {} = {}".format(epoch, val_loss))

        # remember best loss and save the checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': "Conv4Col",
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_decayer.state_dict(),
                'best_loss': best_loss,
            }, is_best, config)
Example #5
0
class Base_eval:
    """Pretrained video-classification baseline used to score rollouts.

    Loads a pretrained MultiColumn model at construction time and exposes
    reward methods that return the softmax class probabilities together
    with the class ranking for a clip, read either from a folder of frame
    images or from an in-memory frame buffer.
    """

    # Clips are padded (by repeating the last frame) to this many frames.
    CLIP_LEN = 72

    def __init__(self, opt=None):
        """Load the pretrained baseline model and put it in eval mode.

        Args:
            opt: optional options object. When provided it must expose
                ``project_root``; it may also expose ``use_refine_baseline``
                to select the "refined" checkpoint directory.
        """
        self.opt = opt
        # Resolve the evaluation root: a hard-coded path when run as a
        # script, otherwise derived from the caller-supplied project root.
        if __name__ == '__main__':
            self.root = '/scr1/system/gamma-robot/scripts/Eval/base_eval'
        else:
            self.root = os.path.join(self.opt.project_root, 'scripts', 'Eval',
                                     'base_eval')

        self.config = load_json_config(
            os.path.join(self.root,
                         "configs/pretrained/config_model1_left_right.json"))

        # setup device - GPU only (CUDA is assumed to be available)
        self.device = torch.device("cuda")
        self.device_ids = [0]

        # create model and wrap it for (multi-)GPU execution
        self.model = MultiColumn(self.config["num_classes"], Model,
                                 int(self.config["column_units"]))
        self.model = torch.nn.DataParallel(self.model,
                                           self.device_ids).to(self.device)

        # Pick the checkpoint directory. ``opt`` may be None or may lack the
        # ``use_refine_baseline`` flag, which raises AttributeError -- fall
        # back to the model named in the config in that case. (The original
        # bare ``except:`` also swallowed SystemExit/KeyboardInterrupt.)
        try:
            use_refined = self.opt.use_refine_baseline
        except AttributeError:
            use_refined = False
        subdir = "refined" if use_refined else self.config['model_name']
        save_dir = os.path.join(self.root, "trained_models/pretrained", subdir)
        checkpoint_path = os.path.join(save_dir, 'model_best.pth.tar')

        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

    def _make_transforms(self):
        """Build the (pre, post) evaluation transform pipelines."""
        transform_pre = ComposeMix([
            [Scale(int(1.4 * self.config['input_spatial_size'])), "img"],
            [torchvision.transforms.ToPILImage(), "img"],
            [
                torchvision.transforms.CenterCrop(
                    self.config['input_spatial_size']), "img"
            ],
        ])

        transform_post = ComposeMix([
            [torchvision.transforms.ToTensor(), "img"],
            [
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # default values for imagenet
                    std=[0.229, 0.224, 0.225]),
                "img"
            ]
        ])
        return transform_pre, transform_post

    def _frames_to_clip(self, imgs):
        """Transform a frame list into a (C, T, H, W) clip tensor.

        Pads short sequences to ``CLIP_LEN`` frames by repeating the
        last frame.
        """
        transform_pre, transform_post = self._make_transforms()
        imgs = transform_pre(imgs)
        imgs = transform_post(imgs)

        if len(imgs) < self.CLIP_LEN:
            imgs.extend([imgs[-1]] * (self.CLIP_LEN - len(imgs)))

        data = torch.stack(imgs)
        # stack gives (T, C, H, W); the model expects (C, T, H, W)
        return data.permute(1, 0, 2, 3)

    def videoloader(self, filepath):
        """Load frames named '<index>.<ext>' from a folder as a clip tensor.

        Frames are read with OpenCV, converted BGR->RGB, and ordered by
        their integer filename index.
        """
        indexed = []
        for file in os.listdir(filepath):
            file_id = int(file.split('.')[0])
            frame = cv2.imread(os.path.join(filepath, file))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            indexed.append([frame, file_id])

        frames = [pair[0] for pair in sorted(indexed, key=lambda x: x[1])]
        return self._frames_to_clip(frames)

    def memroy_loader(self, imgs):
        """Convert an in-memory list of RGB frames into a clip tensor.

        NOTE(review): name kept as-is (typo for "memory_loader") because
        it is part of the public interface.
        """
        return self._frames_to_clip(imgs)

    def _score_clip(self, clip):
        """Run inference on one clip and rank the classes.

        Returns:
            Tuple ``(output, output_index)``: softmax probabilities (1-D
            numpy array) and class indices sorted by descending probability.
        """
        with torch.no_grad():
            batch = clip.float().unsqueeze(0)
            input_var = [batch.to(self.device)]
            output = self.model(input_var, False)
            output = F.softmax(output, 1)
            output = np.squeeze(output.cpu().detach().numpy())

            # negate so argsort yields descending-probability order
            output_index = np.argsort(output * -1.0)

            return (output, output_index)

    def get_baseline_reward(self, filepath):
        """Score the clip stored as frame images under ``filepath``."""
        return self._score_clip(self.videoloader(filepath))

    def get_memory_reward(self, img_buffer):
        """Score the clip held in the in-memory frame list ``img_buffer``."""
        return self._score_clip(self.memroy_loader(img_buffer))
Example #6
0
def main():
    """Train a video model for the cross-dataset-generalization experiment.

    Builds either a MultiColumn 3D model or a ConvLSTM depending on
    ``config['model_name']``, optionally resumes from a checkpoint, splits
    the training set into train/val, trains until ``num_epochs`` or until
    the LR decays below ``config['last_lr']``, and finally evaluates on
    the test split. Metrics are logged to wandb.
    """
    global args, best_loss

    # set run output folder
    config['model_id'] = '_'.join([config["model_name"], args.job_identifier])
    wandb.init(project="cross-dataset-generalization", config=config)
    output_dir = config["output_dir"]
    save_dir = os.path.join(output_dir, config['model_id'])
    print(" > Output folder for this run -- {}".format(save_dir))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    # assign Ctrl+C signal handler so aborted runs clean up their folder
    signal.signal(signal.SIGINT, utils.ExperimentalRunCleaner(save_dir))

    # create model
    print(" > Creating model ... !")
    if '3D' in config['model_name']:
        model = MultiColumn(config['num_classes'], model_def.Model,
                            int(config["column_units"]))

        # multi GPU setting
        model = torch.nn.DataParallel(model, device_ids).to(device)
        input_size = (config['batch_size'], 3, config['clip_size'],
                      config['input_spatial_size'],
                      config['input_spatial_size'])
        seq_first = False
    else:
        model = model_def.ConvLSTMModel(config=config)
        input_size = (config['clip_size'], config['batch_size'], 3,
                      config['input_spatial_size'],
                      config['input_spatial_size'])
        seq_first = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(config['checkpoint_path']):
            print(" > Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(config['checkpoint_path'])
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            print(" > Loaded checkpoint '{}' (epoch {})".format(
                config['checkpoint_path'], checkpoint['epoch']))
        else:
            print(" !#! No checkpoint found at '{}'".format(
                config['checkpoint_path']))

    # define augmentation pipeline
    upscale_size_train = int(config['input_spatial_size'] *
                             config["upscale_factor_train"])
    upscale_size_eval = int(config['input_spatial_size'] *
                            config["upscale_factor_eval"])

    # Random crop videos during training
    transform_train_pre = ComposeMix([
        [RandomRotationVideo(15), "vid"],
        [Scale(upscale_size_train), "img"],
        [RandomCropVideo(config['input_spatial_size']), "vid"],
    ])

    # Center crop videos during evaluation
    transform_eval_pre = ComposeMix([
        [Scale(upscale_size_eval), "img"],
        [torchvision.transforms.ToPILImage(), "img"],
        [
            torchvision.transforms.CenterCrop(config['input_spatial_size']),
            "img"
        ],
    ])

    # Transforms common to train and eval sets and applied after "pre" transforms
    transform_post = ComposeMix([
        [torchvision.transforms.ToTensor(), "img"],
        [
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # default values for imagenet
                std=[0.229, 0.224, 0.225]),
            "img"
        ]
    ])

    train_val_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_train'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_train_val'],
        step_size=config['step_size_train_val'],
        is_val=False,
        transform_pre=transform_train_pre,
        transform_post=transform_post,
        augmentation_mappings_json=config['augmentation_mappings_json'],
        augmentation_types_todo=config['augmentation_types_todo'],
        get_item_id=True,
        seq_first=seq_first)
    # Fixed-seed split so train/val membership is reproducible across runs.
    train_data, val_data = torch.utils.data.random_split(
        train_val_data, [config['nb_train_samples'], config['nb_val_samples']],
        generator=torch.Generator().manual_seed(42))

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    test_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_test'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_test'],
        step_size=config['step_size_test'],
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
        is_test=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=False)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device)

    # define optimizer
    lr = config["lr"]
    last_lr = config["last_lr"]
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    if args.eval_only:
        validate(test_loader, model, criterion, train_data.classes_dict)
        print(" > Evaluation DONE !")
        return

    # set callbacks
    lr_decayer = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            'min',
                                                            factor=0.5,
                                                            patience=2,
                                                            verbose=True)
    val_loss = float('Inf')

    # set end condition by num epochs (-1 means "train until LR converges")
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    print(" > Training is getting started...")
    print(" > Training takes {} epochs.".format(num_epochs))
    start_epoch = args.start_epoch if args.resume else 0

    for epoch in range(start_epoch, num_epochs):

        lrs = [params['lr'] for params in optimizer.param_groups]
        print(" > Current LR(s) -- {}".format(lrs))
        # BUGFIX: compare the *current* per-group LRs (``lrs``), not the
        # initial config scalar ``lr`` -- with ``lr`` this stop condition
        # could never fire after the scheduler decayed the learning rate.
        # NOTE(review): exits with status 1 on normal LR convergence --
        # looks unintentional for a success path, kept for compatibility.
        if np.max(lrs) < last_lr and last_lr > 0:
            print(" > Training is DONE by learning rate {}".format(last_lr))
            sys.exit(1)
        wandb.log({'epoch': epoch})

        # train for one epoch
        train_loss, train_top1, train_top5 = train(train_loader, model,
                                                   criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, val_top1, val_top5 = validate(val_loader,
                                                model,
                                                criterion,
                                                which_split='val')

        # set learning rate
        lr_decayer.step(val_loss, epoch)

        print(" > Validation loss after epoch {} = {}".format(epoch, val_loss))

        # remember best loss and save the checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': "Conv4Col",
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
            }, is_best, config)

    # final evaluation on the held-out test split
    test_loss, test_top1, test_top5 = validate(test_loader,
                                               model,
                                               criterion,
                                               which_split='test')