class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]

            image_tmp, _ = next(iter(self.train_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

            if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR':
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
            elif config.dataset_name == 'ImageNet':
                # the ImageNet loaders are not built on a subset sampler, so
                # the sample counts are hard-coded to match len(train_dataset)
                # in data_loader.py; len(self.train_loader) would be wrong,
                # since it counts batches rather than samples
                self.num_train = 100000
                self.num_valid = 10000
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)

            image_tmp, _ = next(iter(self.test_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

        # assign the number of channels and classes of images in this dataset
        if 'MNIST' in config.dataset_name:
            self.num_channels = 1
            self.num_classes = 10
        elif config.dataset_name == 'ImageNet':
            self.num_channels = 3
            self.num_classes = 1000
        elif config.dataset_name == 'CIFAR':
            self.num_channels = 3
            self.num_classes = 10
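
        # a more robust alternative (sketch): torchvision datasets such as
        # MNIST, CIFAR and ImageFolder expose a `classes` attribute, so the
        # class count could be inferred rather than hard-coded, e.g.:
        #   self.num_classes = len(self.train_loader.dataset.classes)
        # the channel count is tied to how data_loader.py stacks each image
        # with its saliency map, so it is kept explicit here.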

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        prefix = 'ram_gpu' if config.use_gpu else 'ram'
        self.model_name = '{0}_{1}_{2}_{3}x{4}_{5}_{6:1.2f}_{7}'.format(
            prefix, config.PBSarray_ID, config.num_glimpses,
            config.patch_size, config.patch_size, config.hidden_size,
            config.std, config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)
            # bind the writer to self so train/validate can log to it
            self.writer = SummaryWriter(log_dir=tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        if config.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
        elif config.optimizer == 'ReduceLROnPlateau':
            # ReduceLROnPlateau is a scheduler rather than an optimizer: it
            # needs an underlying optimizer to act on and takes no
            # weight_decay argument, so wrap a plain SGD optimizer
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               'min',
                                               patience=self.lr_patience)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(),
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(),
                                    weight_decay=self.weight_decay)

    def reset(self, x, SM):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)
        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)

        # initialize hidden state 1 as a zero vector so that the model
        # cannot classify directly from the context vector
        h_t1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)

        cell_state1 = torch.zeros(self.batch_size,
                                  self.hidden_size).type(dtype)

        cell_state2 = torch.zeros(self.batch_size,
                                  self.hidden_size).type(dtype)

        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # # reduce lr if validation loss plateaus
            # self.scheduler.step(valid_loss)

            is_best_valid = valid_acc > self.best_valid_acc
            is_best_train = train_acc > self.best_train_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"

            if is_best_train:
                msg1 += " [*]"

            if is_best_valid:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best_valid:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.best_train_acc = max(train_acc, self.best_train_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                    'best_train_acc': self.best_train_acc,
                }, is_best_valid)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x_raw, y) in enumerate(self.train_loader):
                #
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # split the stacked input into images and their saliency maps
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(
                    x, SM)
                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []

                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                        x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM,
                        SM_local_smooth)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x,
                    l_t,
                    h_t1,
                    h_t2,
                    cell_state1,
                    cell_state2,
                    SM,
                    SM_local_smooth,
                    last=True)

                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                if self.loss_fun_baseline == 'cross_entropy':
                    # cross_entropy expects a long tensor of class indices as
                    # its target, while R must also be broadcast to
                    # N x num_glimpses before the baseline is subtracted,
                    # hence the cast and repeat after the loss call
                    R = (predicted.detach() == y).long()
                    # compute losses for differentiable modules
                    loss_action, loss_baseline = self.choose_loss_fun(
                        log_probas, y, baselines, R)
                    R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
                else:
                    R = (predicted.detach() == y).float()
                    R = R.unsqueeze(1).repeat(1, self.num_glimpses)
                    # compute losses for differentiable modules
                    loss_action, loss_baseline = self.choose_loss_fun(
                        log_probas, y, baselines, R)

                # loss_action = F.nll_loss(log_probas, y)
                # loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
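                # (this is the REINFORCE estimator: the policy-gradient term
                # is E[-log pi(l_t) * (R - b_t)]; subtracting the learned,
                # detached baseline b_t reduces variance without biasing
                # the gradient)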
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data.item(), x.size()[0])
                accs.update(acc.data.item(), x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.data.item(), acc.data.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))
                    sio.savemat(self.plot_dir +
                                "data_train_{}.mat".format(epoch + 1),
                                mdict={
                                    'location': locs,
                                    'patch': imgs
                                })

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    self.writer.add_scalar('Loss/train', losses.avg,
                                           iteration)
                    self.writer.add_scalar('Accuracy/train', accs.avg,
                                           iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x_raw, y) in enumerate(self.valid_loader):
            if self.use_gpu:
                x_raw, y = x_raw.cuda(), y.cuda()

            # split the stacked input into images and their saliency maps
            x = x_raw[:, 0, ...].unsqueeze(1)
            SM = x_raw[:, 1, ...].unsqueeze(1)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)
            SM = SM.repeat(self.M, 1, 1, 1)
            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(
                x, SM)

            # extract the glimpses
            log_pi = []
            baselines = []

            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM,
                    SM_local_smooth)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                x,
                l_t,
                h_t1,
                h_t2,
                cell_state1,
                cell_state2,
                SM,
                SM_local_smooth,
                last=True)

            # store
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)
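
            # averaging over the M duplicated copies gives a Monte Carlo
            # estimate over glimpse trajectories, in the spirit of the
            # evaluation protocol of the original RAM paper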

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            if self.loss_fun_baseline == 'cross_entropy':
                # cross_entropy expects a long tensor of class indices as
                # its target, while R must also be broadcast to
                # N x num_glimpses before the baseline is subtracted,
                # hence the cast and repeat after the loss call
                R = (predicted.detach() == y).long()
                # compute losses for differentiable modules
                loss_action, loss_baseline = self.choose_loss_fun(
                    log_probas, y, baselines, R)
                R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
            else:
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)
                # compute losses for differentiable modules
                loss_action, loss_baseline = self.choose_loss_fun(
                    log_probas, y, baselines, R)

            # compute losses for differentiable modules
            # loss_action = F.nll_loss(log_probas, y)
            # loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.data.item(), x.size()[0])
            accs.update(acc.data.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                self.writer.add_scalar('Accuracy/valid', accs.avg, iteration)
                self.writer.add_scalar('Loss/valid', losses.avg, iteration)

        return losses.avg, accs.avg

    def choose_loss_fun(self, log_probas, y, baselines, R):
        """
        Use a dictionary of function handles as a replacement
        for a switch-case statement.

        Be careful of the argument data types and shapes!
        """
        loss_fun_pool = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'nll': F.nll_loss,
            'smooth_l1': F.smooth_l1_loss,
            'kl_div': F.kl_div,
            'cross_entropy': F.cross_entropy
        }

        return loss_fun_pool[self.loss_fun_action](
            log_probas, y), loss_fun_pool[self.loss_fun_baseline](baselines, R)
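
    # For example (hypothetical configuration): with loss_fun_action='nll'
    # and loss_fun_baseline='mse', choose_loss_fun(log_probas, y, baselines, R)
    # returns (F.nll_loss(log_probas, y), F.mse_loss(baselines, R)).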

    @torch.no_grad()
    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training. Runs under
        `torch.no_grad()` since no gradients are needed at test time.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x_raw, y) in enumerate(self.test_loader):
            if self.use_gpu:
                x_raw, y = x_raw.cuda(), y.cuda()

            # split the stacked input into images and their saliency maps,
            # mirroring train_one_epoch() and validate()
            x = x_raw[:, 0, ...].unsqueeze(1)
            SM = x_raw[:, 1, ...].unsqueeze(1)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)
            SM = SM.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(
                x, SM)

            # save images and glimpse locations
            locs = []
            baselines = []
            log_pi = []
            imgs = []
            imgs.append(x[0:9])

            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM,
                    SM_local_smooth)

                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                x,
                l_t,
                h_t1,
                h_t2,
                cell_state1,
                cell_state2,
                SM,
                SM_local_smooth,
                last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum().item()

            # dump test data
            if self.use_gpu:
                imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                locs = [l.cpu().data.numpy() for l in locs]
            else:
                imgs = [g.data.numpy().squeeze() for g in imgs]
                locs = [l.data.numpy() for l in locs]

            pickle.dump(imgs, open(self.plot_dir + "g_test.p", "wb"))

            pickle.dump(locs, open(self.plot_dir + "l_test.p", "wb"))
            sio.savemat(self.plot_dir + "test_transient.mat",
                        mdict={'location': locs})

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))
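
        # e.g. (hypothetical name) a model called 'ram_42_6_8x8_256_0.17_0.30'
        # is written to 'ram_42_6_8x8_256_0.17_0.30_ckpt.pth.tar' and, when
        # is_best is True, also copied to
        # 'ram_42_6_8x8_256_0.17_0.30_model_best.pth.tar'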

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
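
    # typical usage: load_checkpoint(best=False) when resuming training
    # (see train() above) and load_checkpoint(best=True) before test()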
Example #2
            plt.plot(x, lossx, label='loss')
            #plt.plot(x, rmsex, label='rmse')
            plt.legend()
            # changepoint: save the curve as a png, since checking
            # tensorboard for this is too cumbersome
            plt.savefig(
                '/media/workdir/hujh/hujh-new/huaweirader_baseline/log/demolog/predrnnloss.png'
            )
            plt.close(1)
            #################################################################################
            # changepoint
            if ind % 100 == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    save_dir=save_dir,
                    filename='predrnncheckpoint.pth.tar')
            ################# valid ########################################################
            #if ind % 1000 == 0 and ind > 0:
            val_compareloss = []
            hss = []
            #model.eval()
            if False:
                with torch.no_grad():
                    val_rmse = AverageMeter()
                    val_losses = AverageMeter()

                    if False:
Example #3
def main(args):
    """ The main training function.

    Only works for single node (be it single or multi-GPU)

    Parameters
    ----------
    args :
        Parsed arguments
    """
    # setup
    ngpus = torch.cuda.device_count()
    if ngpus == 0:
        raise RuntimeWarning("Training requires at least one GPU; it cannot run on CPU only")

    print(f"Working with {ngpus} GPUs")
    if args.optim.lower() == "ranger":
        # No warm up if ranger optimizer
        args.warm = 0

    current_experiment_time = datetime.now().strftime('%Y%m%d_%T').replace(":", "")
    args.exp_name = f"{'debug_' if args.debug else ''}{current_experiment_time}_" \
                    f"_fold{args.fold if not args.full else 'FULL'}" \
                    f"_{args.arch}_{args.width}" \
                    f"_batch{args.batch_size}" \
                    f"_optim{args.optim}" \
                    f"_{args.optim}" \
                    f"_lr{args.lr}-wd{args.weight_decay}_epochs{args.epochs}_deepsup{args.deep_sup}" \
                    f"_{'fp16' if not args.no_fp16 else 'fp32'}" \
                    f"_warm{args.warm}_" \
                    f"_norm{args.norm_layer}{'_swa' + str(args.swa_repeat) if args.swa else ''}" \
                    f"_dropout{args.dropout}" \
                    f"_warm_restart{args.warm_restart}" \
                    f"{'_' + args.com.replace(' ', '_') if args.com else ''}"
    args.save_folder = pathlib.Path(f"./runs/{args.exp_name}")
    args.save_folder.mkdir(parents=True, exist_ok=True)
    args.seg_folder = args.save_folder / "segs"
    args.seg_folder.mkdir(parents=True, exist_ok=True)
    args.save_folder = args.save_folder.resolve()
    save_args(args)
    t_writer = SummaryWriter(str(args.save_folder))

    # Create model
    print(f"Creating {args.arch}")

    model_maker = getattr(models, args.arch)

    model = model_maker(
        4, 3,
        width=args.width, deep_supervision=args.deep_sup,
        norm_layer=get_norm_layer(args.norm_layer), dropout=args.dropout)

    print(f"total number of trainable parameters {count_parameters(model)}")

    if args.swa:
        # Create the average model
        swa_model = model_maker(
            4, 3,
            width=args.width, deep_supervision=args.deep_sup,
            norm_layer=get_norm_layer(args.norm_layer))
        for param in swa_model.parameters():
            param.detach_()
        swa_model = swa_model.cuda()
        swa_model_optim = WeightSWA(swa_model)

    if ngpus > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    print(model)
    model_file = args.save_folder / "model.txt"
    with model_file.open("w") as f:
        print(model, file=f)

    criterion = EDiceLoss().cuda()
    metric = criterion.metric
    print(metric)

    rangered = False  # needed because LR scheduling scheme is different for this optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay, eps=1e-4)


    elif args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                                    nesterov=True)

    elif args.optim == "adamw":
        print(f"weight decay argument will not be used. Default is 11e-2")
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    elif args.optim == "ranger":
        optimizer = Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        rangered = True

    # optionally resume from a checkpoint
    if args.resume:
        reload_ckpt(args, model, optimizer)

    if args.debug:
        args.epochs = 2
        args.warm = 0
        args.val = 1

    if args.full:
        train_dataset, bench_dataset = get_datasets(args.seed, args.debug, full=True)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)

    else:

        train_dataset, val_dataset, bench_dataset = get_datasets(args.seed, args.debug, fold_number=args.fold)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max(1, args.batch_size // 2), shuffle=False,
            pin_memory=False, num_workers=args.workers, collate_fn=determinist_collate)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)
        print("Val dataset number of batch:", len(val_loader))

    print("Train dataset number of batch:", len(train_loader))

    # create grad scaler
    scaler = GradScaler()

    # Actual Train loop

    best = np.inf
    print("start warm-up now!")
    if args.warm != 0:
        tot_iter_train = len(train_loader)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lambda cur_iter: (1 + cur_iter) / (tot_iter_train * args.warm))
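        # stepping this scheduler once per training iteration scales the LR
        # by (1 + cur_iter) / (tot_iter_train * args.warm), i.e. a linear
        # warm-up from ~0 to the base LR over `args.warm` epochs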

    patients_perf = []

    if not args.resume:
        for epoch in range(args.warm):
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, scheduler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, epoch,
                                           t_writer, save_folder=args.save_folder,
                                           no_fp16=args.no_fp16)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

    if args.warm_restart:
        print('Total number of epochs should be divisible by 30, else it will do odd things')
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30, eta_min=1e-7)
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               args.epochs + 30 if not rangered else round(
                                                                   args.epochs * 0.5))
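    # with Ranger (rangered=True) the scheduler is only stepped during the
    # second half of training (see the epoch / args.epochs > 0.5 gate below),
    # hence the shorter horizon of round(args.epochs * 0.5)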
    print("start training now!")
    if args.swa:
        # c = 15, k=3, repeat = 5
        c, k, repeat = 30, 3, args.swa_repeat
        epochs_done = args.epochs
        reboot_lr = 0
        if args.debug:
            c, k, repeat = 2, 1, 2

    for epoch in range(args.start_epoch + args.warm, args.epochs + args.warm):
        try:
            # do_epoch for one epoch
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                           epoch,
                                           t_writer,
                                           save_folder=args.save_folder,
                                           no_fp16=args.no_fp16, patients_perf=patients_perf)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

                if validation_loss < best:
                    best = validation_loss
                    model_dict = model.state_dict()
                    save_checkpoint(
                        dict(
                            epoch=epoch, arch=args.arch,
                            state_dict=model_dict,
                            optimizer=optimizer.state_dict(),
                            scheduler=scheduler.state_dict(),
                        ),
                        save_folder=args.save_folder, )

                ts = time.perf_counter()
                print(f"Val epoch done in {ts - te} s")

            if args.swa:
                if (args.epochs - epoch - c) == 0:
                    reboot_lr = optimizer.param_groups[0]['lr']

            if not rangered:
                scheduler.step()
                print("scheduler stepped!")
            else:
                if epoch / args.epochs > 0.5:
                    scheduler.step()
                    print("scheduler stepped!")

        except KeyboardInterrupt:
            print("Stopping training loop, doing benchmark")
            break

    if args.swa:
        swa_model_optim.update(model)
        print("SWA Model initialised!")
        for i in range(repeat):
            optimizer = torch.optim.Adam(model.parameters(), args.lr / 2, weight_decay=args.weight_decay)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, c + 10)
            for swa_epoch in range(c):
                # do_epoch for one epoch
                ts = time.perf_counter()
                model.train()
                swa_model.train()
                current_epoch = epochs_done + i * c + swa_epoch
                training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer,
                                     current_epoch, t_writer,
                                     scaler, no_fp16=args.no_fp16, patients_perf=patients_perf)
                te = time.perf_counter()
                print(f"Train Epoch done in {te - ts} s")

                t_writer.add_scalar(f"SummaryLoss/train", training_loss, current_epoch)

                # update every k epochs and val:
                print(f"cycle number: {i}, swa_epoch: {swa_epoch}, total_cycle_to_do {repeat}")
                if (swa_epoch + 1) % k == 0:
                    swa_model_optim.update(model)
                    if not args.full:
                        model.eval()
                        swa_model.eval()
                        with torch.no_grad():
                            validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                                   current_epoch,
                                                   t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16)
                            swa_model_loss = step(val_loader, swa_model, criterion, metric, args.deep_sup, optimizer,
                                                  current_epoch,
                                                  t_writer, swa=True, save_folder=args.save_folder,
                                                  no_fp16=args.no_fp16)

                        t_writer.add_scalar(f"SummaryLoss/val", validation_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/swa", swa_model_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit_swa", swa_model_loss - training_loss, current_epoch)
                scheduler.step()
        epochs_added = c * repeat
        save_checkpoint(
            dict(
                epoch=args.epochs + epochs_added, arch=args.arch,
                state_dict=swa_model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )
    else:
        save_checkpoint(
            dict(
                epoch=args.epochs, arch=args.arch,
                state_dict=model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )

    try:
        df_individual_perf = pd.DataFrame.from_records(patients_perf)
        print(df_individual_perf)
        df_individual_perf.to_csv(f'{str(args.save_folder)}/patients_indiv_perf.csv')
        reload_ckpt_bis(f'{str(args.save_folder)}/model_best.pth.tar', model)
        generate_segmentations(bench_loader, model, t_writer, args)
    except KeyboardInterrupt:
        print("Stopping right now!")
Example #4
def main(args, logger):
    writer = SummaryWriter(args.subTensorboardDir)
    model = Vgg().to(device)
    trainSet = Lung(rootDir=args.dataDir, mode='train', size=args.inputSize)
    valSet = Lung(rootDir=args.dataDir, mode='test', size=args.inputSize)
    trainDataloader = DataLoader(trainSet,
                                 batch_size=args.batchSize,
                                 drop_last=True,
                                 shuffle=True,
                                 pin_memory=False,
                                 num_workers=args.numWorkers)
    valDataloader = DataLoader(valSet,
                               batch_size=args.valBatchSize,
                               drop_last=False,
                               shuffle=False,
                               pin_memory=False,
                               num_workers=args.numWorkers)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(model.parameters(), lr=args.lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.apexType)
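    # amp.initialize patches the model and optimizer for mixed precision;
    # args.apexType is presumably an Apex opt_level string such as 'O1'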
    iter = 0
    runningLoss = []
    for epoch in range(args.epoch):
        if epoch != 0 and epoch % args.evalFrequency == 0:
            f1, acc = eval(model, valDataloader, logger)
            writer.add_scalars('f1_acc', {'f1': f1,
                                          'acc': acc}, iter)

        if epoch != 0 and epoch % args.saveFrequency == 0:
            modelName = osp.join(args.subModelDir, 'out_{}.pt'.format(epoch))
            # guard against save failures under distributed training
            stateDict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            torch.save(stateDict, modelName)
            # note: this full checkpoint overwrites the bare state dict
            # saved to the same path above
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict()
            }
            torch.save(checkpoint, modelName)

        for img, lb, _ in trainDataloader:
            # array = np.array(img)
            # for i in range(array.shape[0]):
            #     plt.imshow(array[i, 0, ...], cmap='gray')
            #     plt.show()
            iter += 1
            img, lb = img.to(device), lb.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            loss = criterion(outputs.squeeze(), lb.long())
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            # loss.backward()
            optimizer.step()
            runningLoss.append(loss.item())

            if iter % args.msgFrequency == 0:
                avgLoss = np.mean(runningLoss)
                runningLoss = []
                lr = optimizer.param_groups[0]['lr']
                logger.info(f'epoch: {epoch} / {args.epoch}, '
                            f'iter: {iter} / {len(trainDataloader) * args.epoch}, '
                            f'lr: {lr}, '
                            f'loss: {avgLoss:.4f}')
                writer.add_scalar('loss', avgLoss, iter)

    eval(model, valDataloader, logger)
    modelName = osp.join(args.subModelDir, 'final.pth')
    stateDict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    torch.save(stateDict, modelName)
Example #5
def train_model(dataset=dataset,
                save_dir=save_dir,
                num_classes=num_classes,
                lr=lr,
                num_epochs=nEpochs,
                save_epoch=snapshot,
                useTest=useTest,
                test_interval=nTestInterval):
    """
        Args:
            num_classes (int): Number of classes in the data
            num_epochs (int, optional): Number of epochs to train for.
    """
    file = open('run/log.txt', 'w')

    if modelName == 'C3D':
        model = C3D(num_class=num_classes)
        model.my_load_pretrained_weights('saved_model/c3d.pickle')
        train_params = model.parameters()
        # train_params = [{'params': get_1x_lr_params(model), 'lr': lr},
        #                 {'params': get_10x_lr_params(model), 'lr': lr * 10}]
    # elif modelName == 'R2Plus1D':
    #     model = R2Plus1D_model.R2Plus1DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2))
    #     train_params = [{'params': R2Plus1D_model.get_1x_lr_params(model), 'lr': lr},
    #                     {'params': R2Plus1D_model.get_10x_lr_params(model), 'lr': lr * 10}]
    # elif modelName == 'R3D':
    #     model = R3D_model.R3DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2))
    #     train_params = model.parameters()
    elif modelName == 'Res3D':
        # model = Resnet(num_classes=num_classes, block=resblock, layers=[3, 4, 6, 3])
        # train_params=model.parameters()
        model = generate_model(50)
        model = load_pretrained_model(model,
                                      './saved_model/r3d50_K_200ep.pth',
                                      n_finetune_classes=num_classes)
        train_params = model.parameters()
    else:
        print('We only implemented C3D and R2Plus1D models.')
        raise NotImplementedError
    # standard cross-entropy loss for classification
    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(train_params, lr=lr, betas=(0.9, 0.999), weight_decay=1e-5,
    #                              amsgrad=True)
    optimizer = Ranger(train_params,
                       lr=lr,
                       betas=(.95, 0.999),
                       weight_decay=5e-4)
    print('use ranger')

    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=32,
                                  eta_min=0,
                                  last_epoch=-1)
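    # cosine annealing decays the LR from its base value towards eta_min
    # over T_max scheduler steps; here scheduler.step() is called once per
    # training batch (see the train phase below)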
    # optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10,
    #                                       gamma=0.1)  # the scheduler divides the lr by 10 every 10 epochs
    if resume_epoch == 0:
        print("Training {} from scratch...".format(modelName))
    else:
        checkpoint = torch.load(os.path.join(
            save_dir, 'models',
            saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar'),
                                map_location=lambda storage, loc: storage
                                )  # Load all tensors onto the CPU
        print("Initializing weights from: {}...".format(
            os.path.join(
                save_dir, 'models',
                saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar')))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['opt_dict'])

    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    # model.to(device)
    if torch.cuda.is_available():
        model = model.cuda()
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        model = nn.DataParallel(model)
        criterion.cuda()

    # log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    log_dir = os.path.join(save_dir)
    writer = SummaryWriter(log_dir=log_dir)

    print('Training model on {} dataset...'.format(dataset))
    train_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                               split='train',
                                               clip_len=16),
                                  batch_size=8,
                                  shuffle=True,
                                  num_workers=8)
    val_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                             split='validation',
                                             clip_len=16),
                                batch_size=8,
                                num_workers=8)
    test_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                              split='test',
                                              clip_len=16),
                                 batch_size=8,
                                 num_workers=8)

    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {
        x: len(trainval_loaders[x].dataset)
        for x in ['train', 'val']
    }
    test_size = len(test_dataloader.dataset)
    # my_smooth={'0': 0.88, '1': 0.95, '2': 0.96, '3': 0.79, '4': 0.65, '5': 0.89, '6': 0.88}
    for epoch in range(resume_epoch, num_epochs):
        # each epoch has a training and validation step
        for phase in ['train', 'val']:
            start_time = timeit.default_timer()

            # reset the running loss and corrects
            running_loss = 0.0
            running_corrects = 0.0
            # set model to train() or eval() mode depending on whether it is trained
            # or being validated. Primarily affects layers such as BatchNorm or Dropout.
            if phase == 'train':
                # scheduler.step() is to be called once every epoch during training
                # scheduler.step()
                model.train()
            else:
                model.eval()

            for inputs, labels in tqdm(trainval_loaders[phase]):
                # move inputs and labels to the device the training is taking place on
                inputs = inputs.to(device)
                labels = labels.to(device)
                # inputs = inputs.cuda(non_blocking=True)
                # labels = labels.cuda(non_blocking=True)
                optimizer.zero_grad()

                if phase == 'train':
                    outputs = model(inputs)
                else:
                    with torch.no_grad():
                        outputs = model(inputs)

                probs = nn.Softmax(dim=1)(outputs)
                # the size of output is [bs , 7]
                preds = torch.max(probs, 1)[1]
                # preds is the index of maxnum of output
                # print(outputs)
                # print(torch.max(outputs, 1))

                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    # CosineAnnealingLR.step() takes no metric argument;
                    # it is stepped once per batch here
                    scheduler.step()
                # for name, parms in model.named_parameters():
                #     print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \
                #           ' -->grad_value:', parms.grad)
                #     print('-->name:', name, ' -->grad_value:', parms.grad)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                print('\ntemp/label:{}/{}'.format(preds[0], labels[0]))

            epoch_loss = running_loss / trainval_sizes[phase]
            epoch_acc = running_corrects.double() / trainval_sizes[phase]

            if phase == 'train':
                writer.add_scalar('data/train_loss_epoch', epoch_loss, epoch)
                writer.add_scalar('data/train_acc_epoch', epoch_acc, epoch)
            else:
                writer.add_scalar('data/val_loss_epoch', epoch_loss, epoch)
                writer.add_scalar('data/val_acc_epoch', epoch_acc, epoch)

            print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                phase, epoch + 1, nEpochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")
            file.write("\n[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                phase, epoch + 1, nEpochs, epoch_loss, epoch_acc))

        if epoch % save_epoch == (save_epoch - 1):
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'opt_dict': optimizer.state_dict(),
                },
                os.path.join(save_dir,
                             saveName + '_epoch-' + str(epoch) + '.pth.tar'))
            print("Save model at {}\n".format(
                os.path.join(save_dir,
                             saveName + '_epoch-' + str(epoch) + '.pth.tar')))

        if useTest and epoch % test_interval == (test_interval - 1):
            model.eval()
            start_time = timeit.default_timer()

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.no_grad():
                    outputs = model(inputs)
                probs = nn.Softmax(dim=1)(outputs)
                preds = torch.max(probs, 1)[1]
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / test_size
            epoch_acc = running_corrects.double() / test_size

            writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
            writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch)

            print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(
                epoch + 1, nEpochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")
            file.write("\n[test] Epoch: {}/{} Loss: {} Acc: {}\n".format(
                epoch + 1, nEpochs, epoch_loss, epoch_acc))
    writer.close()
    file.close()
Example #6
def main():
    global best_acc, mean, std, scale

    args = parse_args()
    args.mean, args.std, args.scale = mean, std, scale
    args.is_master = args.local_rank == 0

    if args.deterministic:
        cudnn.deterministic = True
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.is_master:
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
              type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
        print(f"Distributed Training Enabled: {args.distributed}")

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        # Scale learning rate based on global batch size
        # args.lr *= args.batch_size * args.world_size / 256

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    model = models.ResNet18(args.num_patches, args.num_angles)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    optimiser = Ranger(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss().cuda()

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimiser = amp.initialize(
        model,
        optimiser,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            global best_acc
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_acc = checkpoint['best_acc']
                args.poisson_rate = checkpoint["poisson_rate"]
                model.load_state_dict(checkpoint['state_dict'])
                optimiser.load_state_dict(checkpoint['optimiser'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    train_dir = os.path.join(args.data, 'train')
    val_dir = os.path.join(args.data, 'val')

    crop_size = 225
    val_size = 256

    imagenet_train = datasets.ImageFolder(
        root=train_dir,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
        ]))
    train_dataset = SSLTrainDataset(imagenet_train, args.num_patches,
                                    args.num_angles, args.poisson_rate)
    imagenet_val = datasets.ImageFolder(root=val_dir,
                                        transform=transforms.Compose([
                                            transforms.Resize(val_size),
                                            transforms.CenterCrop(crop_size),
                                        ]))
    val_dataset = SSLValDataset(imagenet_val, args.num_patches,
                                args.num_angles)

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler,
                              collate_fn=fast_collate)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=val_sampler,
                            collate_fn=fast_collate)

    if args.evaluate:
        val_loss, val_acc = apex_validate(val_loader, model, criterion, args)
        utils.logger.info(f"Val Loss = {val_loss}, Val Accuracy = {val_acc}")
        return

    # Create dir to save model and command-line args
    if args.is_master:
        model_dir = time.ctime().replace(" ", "_").replace(":", "_")
        model_dir = os.path.join("models", model_dir)
        os.makedirs(model_dir, exist_ok=True)
        with open(os.path.join(model_dir, "args.json"), "w") as f:
            json.dump(args.__dict__, f, indent=2)
        writer = SummaryWriter()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_loss, train_acc = apex_train(train_loader, model, criterion,
                                           optimiser, args, epoch)

        # evaluate on validation set
        val_loss, val_acc = apex_validate(val_loader, model, criterion, args)

        if (epoch + 1) % args.learn_prd == 0:
            utils.adj_poisson_rate(train_loader, args)

        # remember best Acc and save checkpoint
        if args.is_master:
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimiser': optimiser.state_dict(),
                    "poisson_rate": args.poisson_rate
                }, is_best, model_dir)

            writer.add_scalars("Loss", {
                "train_loss": train_loss,
                "val_loss": val_loss
            }, epoch)
            writer.add_scalars("Accuracy", {
                "train_acc": train_acc,
                "val_acc": val_acc
            }, epoch)
            writer.add_scalar("Poisson_Rate", train_loader.dataset.pdist.rate,
                              epoch)