Exemplo n.º 1
0
def build_loader_model_grapher(args):
    """Build a dataloader, a classification network and a grapher.

    Mutates ``args`` in place with dataset sizing fields used downstream
    (input_shape, num_*_samples, steps_per_train_epoch, total_train_steps).

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher (grapher is None on
              non-zero distributed ranks)
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    # the loader pulls the rest of its configuration straight from args
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform,
                   **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # Test isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network; args.arch selects a constructor from the models module
    network = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=loader.output_size)
    network = nn.SyncBatchNorm.convert_sync_batchnorm(network) if args.convert_to_sync_bn else network
    network = torch.jit.script(network) if args.jit else network
    network = network.cuda() if args.cuda else network
    # presumably instantiates lazily-built submodules using a real batch,
    # which is why it runs before weight init -- TODO confirm semantics
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # build the grapher object; only rank 0 graphs, other ranks keep None
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom', env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
Exemplo n.º 2
0
def build_loader_model_grapher(args):
    """Build a dataloader, a VAE network and a grapher.

    Sets ``args.input_shape`` as a side effect.

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher
    :rtype: tuple

    """
    target_hw = (args.image_size_override, args.image_size_override)
    if args.image_size_override:
        transform = [torchvision.transforms.Resize(target_hw)]
    else:
        transform = None

    # build the loader; it reads the rest of its config from args
    loader = get_loader(args, transform=transform, **vars(args))

    # record the (possibly resized) input size on the args namespace
    if args.image_size_override is None:
        args.input_shape = loader.img_shp
    else:
        args.input_shape = [loader.img_shp[0], *target_hw]

    # map the requested vae type to its constructor and build the network
    vae_registry = {
        'simple': SimpleVAE,
        'msg': MSGVAE,
        'parallel': ParallellyReparameterizedVAE,
        'sequential': SequentiallyReparameterizedVAE,
        'vrnn': VRNN,
    }
    network = vae_registry[args.vae_type](loader.img_shp, kwargs=deepcopy(vars(args)))
    lazy_generate_modules(network, loader.train_loader)
    if args.cuda:
        network = network.cuda()
    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # build the grapher object: visdom when a url is configured, else tensorboard
    if args.visdom_url:
        grapher = Grapher('visdom', env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
Exemplo n.º 3
0
def build_loader_model_grapher(args):
    """Build a dataloader, a resnet18 classifier and a grapher.

    Sets ``args.input_shape`` as a side effect.

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher
    :rtype: tuple

    """
    target_hw = (args.image_size_override, args.image_size_override)
    if args.image_size_override:
        transform = [transforms.Resize(target_hw)]
    else:
        transform = None

    # build the loader; it reads the rest of its config from args
    loader = get_loader(args, transform=transform,
                        **vars(args))

    # record the (possibly resized) input size on the args namespace
    if args.image_size_override is None:
        args.input_shape = loader.img_shp
    else:
        args.input_shape = [loader.img_shp[0], *target_hw]

    # build the network; to use your own model import and construct it here
    network = resnet18(num_classes=loader.output_size)
    lazy_generate_modules(network, loader.train_loader)
    if args.cuda:
        network = network.cuda()
    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # build the grapher object: visdom when a url is configured, else tensorboard
    if args.visdom_url:
        grapher = Grapher('visdom',
                          env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
Exemplo n.º 4
0
def get_model_and_loader():
    ''' helper to return the model and the loader '''
    # NOTE: reads the module-level `args` namespace instead of taking params.
    aux_transform = None
    if args.synthetic_upsample_size > 0 and args.task == "multi_image_folder":
        # bilinearly upsample each sample to the synthetic size;
        # F.interpolate needs a batch dim, hence the unsqueeze/squeeze pair
        aux_transform = lambda x: F.interpolate(
            torchvision.transforms.ToTensor()(x).unsqueeze(0),
            size=(args.synthetic_upsample_size, args.synthetic_upsample_size),
            mode='bilinear',
            align_corners=True).squeeze(0)

    # resizer = torchvision.transforms.Resize(size=(args.synthetic_upsample_size,
    #                                               args.synthetic_upsample_size))
    loader = get_loader(
        args,
        transform=None,  #transform=[resizer],
        sequentially_merge_test=False,
        aux_transform=aux_transform,
        postfix="_large",
        **vars(args))

    # append the image shape to the config & build the VAE
    args.img_shp = loader.img_shp
    vae = VRNN(
        loader.img_shp,
        n_layers=2,  # XXX: hard coded
        #bidirectional=True,    # XXX: hard coded
        bidirectional=False,  # XXX: hard coded
        kwargs=vars(args))

    # build the Variational Saccading module
    # and lazy generate the non-constructed modules
    saccader = Saccader(vae, loader.output_size, kwargs=vars(args))
    lazy_generate_modules(saccader, loader.train_loader)

    # FP16-ize, cuda-ize and parallelize (if requested)
    saccader = saccader.fp16() if args.half is True else saccader
    saccader = saccader.cuda() if args.cuda is True else saccader
    # presumably parallelizes in place; the expression result is discarded
    saccader.parallel() if args.ngpu > 1 else saccader

    # build the grapher object (tensorboard or visdom)
    # and plot config json to visdom
    if args.visdom_url is not None:
        grapher = Grapher('visdom',
                          env=saccader.get_name(),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=saccader.get_name())

    grapher.add_text('config',
                     pprint.PrettyPrinter(indent=4).pformat(saccader.config),
                     0)

    # register_nan_checks(saccader)
    return [saccader, loader, grapher]
Exemplo n.º 5
0
def _set_model_indices(model, grapher, idx, args):
    """Advance the student-teacher model to distribution index ``idx``.

    For ``idx > 0``: re-instantiates clean student/teacher VAEs (with the
    discrete latent size grown with idx) so that saved parameters can be
    loaded into them later, and rebuilds the grapher under the new model
    name. For ``idx == 0`` both arguments pass through unchanged.

    :param model: student-teacher container; mutated in place
    :param grapher: current grapher; replaced when idx > 0
    :param idx: index of the current distribution/task
    :param args: argparse namespace
    :returns: (model, grapher)
    """
    def _init_vae(img_shp, config):
        # spawn the requested VAE flavor for the given input shape
        if args.vae_type == 'sequential':
            # Sequential : P(y|x) --> P(z|y, x) --> P(x|z)
            # Keep a separate VAE spawn here in case we want
            # to parameterize the sequence of reparameterizers
            vae = SequentiallyReparameterizedVAE(img_shp, **{'kwargs': config})
        elif args.vae_type == 'parallel':
            # Ours: [P(y|x), P(z|x)] --> P(x | z)
            vae = ParallellyReparameterizedVAE(img_shp, **{'kwargs': config})
        else:
            raise Exception("unknown VAE type requested")

        return vae

    if idx > 0:  # create some clean models to later load in params
        model.current_model = idx
        if not args.disable_augmentation:
            # ratio of teacher-generated (synthetic) samples in each batch
            model.ratio = idx / (idx + 1.0)
            num_teacher_samples = int(args.batch_size * model.ratio)
            num_student_samples = max(args.batch_size - num_teacher_samples, 1)
            print("#teacher_samples: ", num_teacher_samples,
                  " | #student_samples: ", num_student_samples)

            # copy args and reinit clean models for student and teacher
            # NOTE(review): vars(args) is a live view of args.__dict__;
            # the deepcopies below shield args from the size mutations
            config_base = vars(args)
            config_teacher = deepcopy(config_base)
            config_student = deepcopy(config_base)
            config_teacher['discrete_size'] += idx - 1
            config_student['discrete_size'] += idx
            model.student = _init_vae(model.student.input_shape,
                                      config_student)
            if not args.disable_student_teacher:
                model.teacher = _init_vae(model.student.input_shape,
                                          config_teacher)

        # re-init grapher
        grapher = Grapher(env=model.get_name(),
                          server=args.visdom_url,
                          port=args.visdom_port)

    return model, grapher
Exemplo n.º 6
0
def get_model_and_loader():
    '''helper to return the student-teacher model, the loader(s) and the
    grapher; reads the module-level `args` namespace'''
    if args.disable_sequential:  # vanilla batch training
        loaders = get_loader(args)
        loaders = [loaders] if not isinstance(loaders, list) else loaders
    else:  # classes split
        loaders = get_split_data_loaders(args, num_classes=10)

    for l in loaders:
        print("train = ", num_samples_in_loader(l.train_loader), " | test = ",
              num_samples_in_loader(l.test_loader))

    # append the image shape to the config & build the VAE
    # BUGFIX: a stray trailing comma previously made this a 1-element tuple
    # (i.e. (img_shp,)) instead of the image shape itself.
    args.img_shp = loaders[0].img_shp
    if args.vae_type == 'sequential':
        # Sequential : P(y|x) --> P(z|y, x) --> P(x|z)
        # Keep a separate VAE spawn here in case we want
        # to parameterize the sequence of reparameterizers
        vae = SequentiallyReparameterizedVAE(loaders[0].img_shp,
                                             kwargs=vars(args))
    elif args.vae_type == 'parallel':
        # Ours: [P(y|x), P(z|x)] --> P(x | z)
        vae = ParallellyReparameterizedVAE(loaders[0].img_shp,
                                           kwargs=vars(args))
    else:
        raise Exception("unknown VAE type requested")

    # build the combiner which takes in the VAE as a parameter
    # and projects the latent representation to the output space
    student_teacher = StudentTeacher(vae, kwargs=vars(args))
    #student_teacher = init_weights(student_teacher)

    # build the grapher object
    grapher = Grapher(env=student_teacher.get_name(),
                      server=args.visdom_url,
                      port=args.visdom_port)

    return [student_teacher, loaders, grapher]
Exemplo n.º 7
0
def train_loop(data_loaders, model, fid_model, grapher, args):
    """Run the full sequential training loop over all data loaders.

    For each loader (distribution): train/test for ``args.epochs`` epochs
    with optional early stopping, dump one-time metrics (test ELBO, sample
    counts, consistency vs. the previous loader, FID) to CSV and visdom,
    optionally accumulate an EWC fisher matrix, then fork a new student and
    rebuild the optimizer/grapher for the next distribution.

    :param data_loaders: list of loaders, one per distribution/task
    :param model: student-teacher container model
    :param fid_model: model used by calculate_fid
    :param grapher: visdom grapher; rebuilt after each fork
    :param args: argparse namespace
    :returns: None
    """
    optimizer = build_optimizer(model.student)  # collect our optimizer
    print(
        "there are {} params with {} elems in the st-model and {} params in the student with {} elems"
        .format(len(list(model.parameters())), number_of_parameters(model),
                len(list(model.student.parameters())),
                number_of_parameters(model.student)))

    # main training loop
    fisher = None
    for j, loader in enumerate(data_loaders):
        num_epochs = args.epochs  # TODO: randomize epochs by something like: + np.random.randint(0, 13)
        print("training current distribution for {} epochs".format(num_epochs))
        early = EarlyStopping(
            model, max_steps=50,
            burn_in_interval=None) if args.early_stop else None
        #burn_in_interval=int(num_epochs*0.2)) if args.early_stop else None

        test_loss = None
        for epoch in range(1, num_epochs + 1):
            train(epoch, model, fisher, optimizer, loader.train_loader,
                  grapher)
            test_loss = test(epoch, model, fisher, loader.test_loader, grapher)
            if args.early_stop and early(test_loss['loss_mean']):
                early.restore()  # restore and test+generate again
                test_loss = test_and_generate(epoch, model, fisher, loader,
                                              grapher)
                break

            generate(model, grapher, 'student')  # generate student samples
            generate(model, grapher, 'teacher')  # generate teacher samples

        # evaluate and save away one-time metrics, these include:
        #    1. test elbo
        #    2. FID
        #    3. consistency
        #    4. num synth + num true samples
        #    5. dump config to visdom
        check_or_create_dir(os.path.join(args.output_dir))
        # BUGFIX: this append used to be duplicated verbatim, writing each
        # ELBO value to the CSV twice per distribution.
        append_to_csv([test_loss['elbo_mean']],
                      os.path.join(args.output_dir,
                                   "{}_test_elbo.csv".format(args.uid)))
        # `epoch` holds the last epoch reached above (early stop or full run)
        num_synth_samples = np.ceil(epoch * args.batch_size * model.ratio)
        num_true_samples = np.ceil(epoch * (args.batch_size -
                                            (args.batch_size * model.ratio)))
        append_to_csv([num_synth_samples],
                      os.path.join(args.output_dir,
                                   "{}_numsynth.csv".format(args.uid)))
        append_to_csv([num_true_samples],
                      os.path.join(args.output_dir,
                                   "{}_numtrue.csv".format(args.uid)))
        append_to_csv([epoch],
                      os.path.join(args.output_dir,
                                   "{}_epochs.csv".format(args.uid)))
        # NOTE(review): visdom's text pane expects a string; passing the
        # numeric values here relies on visdom coercing them -- confirm
        grapher.vis.text(num_synth_samples,
                         opts=dict(title="num_synthetic_samples"))
        grapher.vis.text(num_true_samples, opts=dict(title="num_true_samples"))
        grapher.vis.text(pprint.PrettyPrinter(indent=4).pformat(
            model.student.config),
                         opts=dict(title="config"))

        # calc the consistency using the **PREVIOUS** loader
        if j > 0:
            append_to_csv(
                calculate_consistency(model, data_loaders[j - 1],
                                      args.reparam_type, args.vae_type,
                                      args.cuda),
                os.path.join(args.output_dir,
                             "{}_consistency.csv".format(args.uid)))

        if args.calculate_fid_with is not None:
            # TODO: parameterize num fid samples, currently use less for inceptionv3 as it's COSTLY
            num_fid_samples = 4000 if args.calculate_fid_with != 'inceptionv3' else 1000
            append_to_csv(
                calculate_fid(fid_model=fid_model,
                              model=model,
                              loader=loader,
                              grapher=grapher,
                              num_samples=num_fid_samples,
                              cuda=args.cuda),
                os.path.join(args.output_dir, "{}_fid.csv".format(args.uid)))

        grapher.save()  # save the remote visdom graphs
        if j != len(data_loaders) - 1:  # skip bookkeeping after the final loader
            if args.ewc_gamma > 0:
                # calculate the fisher from the previous data loader
                print("computing fisher info matrix....")
                fisher_tmp = estimate_fisher(
                    model.student,  # this is pre-fork
                    loader,
                    args.batch_size,
                    cuda=args.cuda)
                if fisher is not None:
                    assert len(fisher) == len(
                        fisher_tmp), "#fisher params != #new fisher params"
                    for (kf, vf), (kft, vft) in zip(fisher.items(),
                                                    fisher_tmp.items()):
                        fisher[kf] += fisher_tmp[kft]
                else:
                    fisher = fisher_tmp

            # spawn a new student & rebuild grapher; we also pass
            # the new model's parameters through a new optimizer.
            if not args.disable_student_teacher:
                model.fork()
                lazy_generate_modules(model, data_loaders[0].img_shp)
                optimizer = build_optimizer(model.student)
                print(
                    "there are {} params with {} elems in the st-model and {} params in the student with {} elems"
                    .format(len(list(model.parameters())),
                            number_of_parameters(model),
                            len(list(model.student.parameters())),
                            number_of_parameters(model.student)))

            else:
                # increment anyway for vanilla models
                # so that we can have a separate visdom env
                model.current_model += 1

            grapher = Grapher(env=model.get_name(),
                              server=args.visdom_url,
                              port=args.visdom_port)
Exemplo n.º 8
0
def build_loader_model_grapher(args):
    """Build a dataloader, a VAE network and a grapher.

    Mutates ``args`` in place with dataset sizing fields used downstream
    (input_shape, output_size, num_*_samples, steps_per_train_epoch,
    total_train_steps).

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher (grapher is None on
              non-zero distributed ranks)
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    # the loader pulls the rest of its configuration straight from args
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform, **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.output_size = loader.output_size
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # Test isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network; build_vae maps args.vae_type to a constructor
    network = build_vae(args.vae_type)(loader.input_shape, kwargs=deepcopy(vars(args)))
    network = network.cuda() if args.cuda else network
    # presumably instantiates lazily-built submodules using a real batch,
    # which is why it runs before weight init -- TODO confirm semantics
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # add the test set as a np array for metrics calc
    if args.metrics_server is not None:
        network.test_images = get_numpy_dataset(task=args.task,
                                                data_dir=args.data_dir,
                                                test_transform=test_transform,
                                                split='test',
                                                image_size=args.image_size_override,
                                                cuda=args.cuda)
        print("Metrics test images: ", network.test_images.shape)

    # build the grapher object; only rank 0 graphs, other ranks keep None
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom', env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
Exemplo n.º 9
0
def get_model_and_loader():
    '''helper to return the model and the loader; reads the module-level
    `args` namespace'''
    aux_transform = None
    if args.synthetic_upsample_size > 0:  #and args.task == "multi_image_folder":
        to_pil = torchvision.transforms.ToPILImage()
        to_tensor = torchvision.transforms.ToTensor()
        resizer = torchvision.transforms.Resize(
            size=(args.synthetic_upsample_size, args.synthetic_upsample_size),
            interpolation=2)

        def extract_patches_2D(img, size):
            """Chop a (B, C, H, W) tensor into size[0] x size[1] patches,
            adding edge-aligned patches when H or W doesn't divide evenly.

            Returns a (num_patches, C, patch_H, patch_W) tensor.
            """
            patch_H, patch_W = min(img.size(2),
                                   size[0]), min(img.size(3), size[1])
            patches_fold_H = img.unfold(2, patch_H, patch_H)
            if (img.size(2) % patch_H != 0):
                # height remainder: append one extra row of patches
                # aligned to the bottom edge
                patches_fold_H = torch.cat(
                    (patches_fold_H, img[:, :, -patch_H:, ].permute(
                        0, 1, 3, 2).unsqueeze(2)),
                    dim=2)

            # BUGFIX: this unfold and the final reshape were previously
            # nested inside the remainder branches above/below, raising
            # UnboundLocalError whenever H (or W) divided evenly.
            patches_fold_HW = patches_fold_H.unfold(3, patch_W, patch_W)
            if (img.size(3) % patch_W != 0):
                # width remainder: append one extra column of patches
                # aligned to the right edge
                patches_fold_HW = torch.cat(
                    (patches_fold_HW,
                     patches_fold_H[:, :, :, -patch_W:, :].permute(
                         0, 1, 2, 4, 3).unsqueeze(3)),
                    dim=3)

            patches = patches_fold_HW.permute(0, 2, 3, 1, 4, 5).reshape(
                -1, img.size(1), patch_H, patch_W)
            return patches

        def patch_extractor_lambda(crop):
            # ensure a batch dimension, then chop into 224x224 patches
            crop = crop.unsqueeze(0) if len(crop.shape) < 4 else crop
            return extract_patches_2D(crop, [224, 224])

        # tensor -> PIL roundtrip so Resize (operating on PIL images here)
        # can handle whatever the dataset yields, then patchify
        aux_transform = lambda x: patch_extractor_lambda(
            to_tensor(resizer(to_pil(to_tensor(x)))))

    loader = get_loader(args,
                        transform=None,
                        sequentially_merge_test=False,
                        aux_transform=aux_transform,
                        postfix="_large",
                        **vars(args))

    # append the image shape to the config & build the model
    args.img_shp = loader.img_shp
    model = MultiBatchModule(loader.output_size, checkpoint=args.checkpoint)

    # FP16-ize, cuda-ize and parallelize (if requested)
    model = model.half() if args.half is True else model
    model = model.cuda() if args.cuda is True else model
    model = nn.DataParallel(model) if args.ngpu > 1 else model

    # build the grapher object (tensorboard or visdom)
    # and plot config json to visdom
    if args.visdom_url is not None:
        grapher = Grapher('visdom',
                          env=get_name(),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name())

    grapher.add_text('config',
                     pprint.PrettyPrinter(indent=4).pformat(vars(args)), 0)
    return [model, loader, grapher]
Exemplo n.º 10
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            # NOTE(review): the "valid" loader is actually the test loader;
            # there is no separate validation split here
            self.train_loader = data_loader.train_loader
            self.valid_loader = data_loader.test_loader

            self.num_train = len(self.train_loader.dataset)
            self.num_valid = len(self.valid_loader.dataset)
        else:
            self.test_loader = data_loader.test_loader
            self.num_test = len(self.test_loader.dataset)

        self.num_classes = data_loader.output_size
        self.num_channels = 1  # hard-coded single-channel input -- confirm dataset

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.cuda
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(
            config.num_glimpses, config.patch_size,
            config.patch_size, config.glimpse_scale
        )

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size, self.num_patches, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden,
            self.std, self.hidden_size, self.num_classes,
        )

        if self.use_gpu:
            self.model.cuda()
            # NOTE(review): device ids are hard-coded to [0, 1] -- confirm
            # this matches the machines this is deployed on
            self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)


        # visdom
        self.use_visdom = config.visdom
        self.visdom_images = config.visdom_images
        self.visdom_env = config.visdom_env

        if self.use_visdom:
            self.grapher = Grapher('visdom',
                            env=self.visdom_env,
                            server=config.visdom_url,
                            port=config.visdom_port)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # NOTE(review): lr is fixed at 3e-4 here; config.init_lr (stored
        # above as self.lr) is not what the optimizer actually uses
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=3e-4,
        )
Exemplo n.º 11
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        # M: number of duplicated copies of each sample used to
        # Monte-Carlo average predictions during validation/testing.
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader.train_loader
            # NOTE(review): the "validation" split here is actually the
            # test loader — confirm data_loader exposes a real valid split.
            self.valid_loader = data_loader.test_loader

            self.num_train = len(self.train_loader.dataset)
            self.num_valid = len(self.valid_loader.dataset)
        else:
            self.test_loader = data_loader.test_loader
            self.num_test = len(self.test_loader.dataset)

        self.num_classes = data_loader.output_size
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.cuda
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        # counter: consecutive epochs without validation improvement
        # (early stopping once it exceeds train_patience).
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(
            config.num_glimpses, config.patch_size,
            config.patch_size, config.glimpse_scale
        )

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size, self.num_patches, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden,
            self.std, self.hidden_size, self.num_classes,
        )

        if self.use_gpu:
            self.model.cuda()
            # NOTE(review): device ids are hard-coded — assumes exactly
            # two visible GPUs; confirm or make configurable.
            self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])

        self.plot_dir = './plots/' + self.model_name + '/'
        # exist_ok avoids the exists()/makedirs() race of the original.
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # visdom
        self.use_visdom = config.visdom
        self.visdom_images = config.visdom_images
        self.visdom_env = config.visdom_env

        if self.use_visdom:
            self.grapher = Grapher('visdom',
                                   env=self.visdom_env,
                                   server=config.visdom_url,
                                   port=config.visdom_port)

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # FIX: use the configured learning rate so the LR printed in the
        # epoch banner matches what the optimizer actually uses (this was
        # previously hard-coded to 3e-4, silently ignoring config.init_lr).
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
        )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced. Reads `self.batch_size`, which must be set
        by the caller before invoking reset().

        Returns
        -------
        - h_t: zero-initialized hidden state, shape (batch_size, hidden_size).
        - l_t: first glimpse location sampled uniformly in [-1, 1],
          shape (batch_size, 2).
        """
        dtype = (
            torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor
        )

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid)
        )

        for epoch in range(self.start_epoch, self.epochs):

            print(
                '\nEpoch: {}/{} - LR: {:.6f}'.format(
                    epoch+1, self.epochs, self.lr)
            )

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # # reduce lr if validation loss plateaus
            # self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement; stop early once patience is exceeded
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {'epoch': epoch + 1,
                 'model_state': self.model.state_dict(),
                 'optim_state': self.optimizer.state_dict(),
                 'best_valid_acc': self.best_valid_acc,
                 }, is_best
            )

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (average loss, average accuracy) over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        softmax_acc = 0
        num_batches = 0

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                num_batches = i + 1
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses; keep per-step log-probs of the
                # location policy and baseline estimates for REINFORCE
                log_pi = []
                baselines = []
                glimpses = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    phi, h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    glimpses.append(phi)
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration additionally yields class log-probs
                phi, h_t, l_t, b_t, log_probas, p = self.model(
                    x, l_t, h_t, last=True
                )

                glimpses.append(phi)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape to (B, num_glimpses)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward: 1 for a correct prediction, else 0,
                # repeated across all glimpse timesteps
                predicted = torch.max(log_probas, 1)[1]

                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # softmax accuracy
                softmax_acc += softmax_accuracy(log_probas, y)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time (cumulative since epoch start)
                toc = time.time()
                batch_time.update(toc-tic)

                pbar.set_description(
                    (
                        "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc-tic), loss.item(), acc.item()
                        )
                    )
                )
                pbar.update(self.batch_size)

        # FIX: average over the actual number of batches. The original
        # divided by `i` (last enumerate index), which is off by one and
        # raises ZeroDivisionError for a single-batch loader.
        softmax_acc /= max(num_batches, 1)

        # Only per epoch to tensorboard
        if self.use_tensorboard:
            log_value('train_loss', losses.avg, epoch)
            log_value('train_acc', accs.avg, epoch)

        # Per epoch to visdom
        if self.use_visdom:
            # Do visdom train acc and train loss
            register_plots({'mean': np.array(losses.avg)}, self.grapher, epoch, prefix='train loss')
            register_plots({'mean': np.array(accs.avg)}, self.grapher, epoch, prefix='train accuracy')
            register_plots({'mean': np.array(softmax_acc)}, self.grapher, epoch, prefix='softmax train accuracy')
            self.grapher.show()

        # Todo: code glimse development over time, or location over image
        if self.use_visdom and self.visdom_images:
                # `glimpses` holds the phi tensors from the LAST batch only
                phi_tensors = []
                for j, phi in enumerate(glimpses):
                    # stack all phi images from the glimpse list
                    phi_row = phi.cpu().data.detach().view((-1, self.num_patches, self.patch_size, self.patch_size))
                    phi_tensors.append(phi_row.squeeze())
                    register_images(phi_row, 'train glimpse', self.grapher, prefix='train_' + str(epoch) + '_g_' + str(j))
                    self.grapher.show()

                image_grid_tensor = torch.stack(phi_tensors).view(self.num_glimpses * self.batch_size, 1, self.patch_size, self.patch_size)
                register_images(image_grid_tensor, 'train glimpse', self.grapher, prefix='train_' + str(epoch))
                self.grapher.show()

        return losses.avg, accs.avg


    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each sample is duplicated M times and the resulting class
        log-probabilities, baselines and location log-probs are averaged
        over the M copies before computing losses and accuracy.

        Returns
        -------
        - (average loss, average accuracy) over the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        softmax_acc = 0
        num_batches = 0

        for i, (x, y) in enumerate(self.valid_loader):
            num_batches = i + 1
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate M times for Monte-Carlo averaging
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for _ in range(self.num_glimpses - 1):
                # forward pass through model
                _, h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            _, h_t, l_t, b_t, log_probas, p = self.model(
                x, l_t, h_t, last=True
            )
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average over the M duplicated copies
            log_probas = log_probas.view(
                self.M, -1, log_probas.shape[-1]
            )
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(
                self.M, -1, baselines.shape[-1]
            )
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(
                self.M, -1, log_pi.shape[-1]
            )
            log_pi = torch.mean(log_pi, dim=0)

            # NOTE(review): this may be averaged wrong over the
            # M-repetition in x — verify softmax_accuracy semantics.
            softmax_acc += softmax_accuracy(log_probas, y)

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store (weighted by the M-duplicated batch size, as before)
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

        # FIX: average over the actual number of batches (was `/= i`,
        # an off-by-one that also breaks on a single-batch loader).
        softmax_acc /= max(num_batches, 1)

        # log to tensorboard per epoch instead of per iteration
        if self.use_tensorboard:
            log_value('valid_loss', losses.avg, epoch)
            log_value('valid_acc', accs.avg, epoch)

        if self.use_visdom:
            # Do visdom train acc and train loss
            register_plots({'mean': np.array(losses.avg)}, self.grapher, epoch, prefix='validation loss')
            register_plots({'mean': np.array(accs.avg)}, self.grapher, epoch, prefix='validation accuracy')
            register_plots({'mean': np.array(softmax_acc)}, self.grapher, epoch, prefix='softmax validation accuracy')
            self.grapher.show()

        return losses.avg, accs.avg


    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # Run inference without building autograd graphs. This replaces
        # the long-removed `Variable(x, volatile=True)` idiom.
        with torch.no_grad():
            for _, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # duplicate M times for Monte-Carlo averaging
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                # FIX: the model returns (phi, h_t, l_t, b_t, p) — the
                # original unpacked only 4/5 values here, which raised a
                # ValueError on every forward pass.
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    _, h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                _, h_t, l_t, b_t, log_probas, p = self.model(
                    x, l_t, h_t, last=True
                )

                log_probas = log_probas.view(
                    self.M, -1, log_probas.shape[-1]
                )
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print(
            '[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
                correct, self.num_test, perc, error)
        )

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a seperate file with the suffix `best` is created.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(
                ckpt_path, os.path.join(self.ckpt_dir, filename)
            )

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt['epoch'], ckpt['best_valid_acc'])
            )
        else:
            print(
                "[*] Loaded {} checkpoint @ epoch {}".format(
                    filename, ckpt['epoch'])
            )