Example #1
    def print_model_structure(self):

        img_size = self.img_size

        # count model parameters
        nparams_g = count_model_params(self.G)
        nparams_d = count_model_params(self.D)

        with open(self.log_dir+'/model_structure_{}x{}.txt'.format(img_size, img_size),'a') as f:
            print('--------------------------------------------------', file=f)
            print('Sequences in Dataset: ', len(self.dataset), file=f)
            print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
            print('Phase: ', self.phase, file=f)
            print('Number of Generator`s model parameters: ', file=f)
            print(nparams_g, file=f)
            print('Number of Discriminator`s model parameters: ', file=f)
            print(nparams_d, file=f)
            print('--------------------------------------------------', file=f)
            print('New Generator structure: ', file=f)
            print(self.G.module, file=f)
            print('--------------------------------------------------', file=f)
            print('New Discriminator structure: ', file=f)
            print(self.D.module, file=f)
            print('--------------------------------------------------', file=f)
            print(' ... models are being updated ... ')
            print(' ... saving updated model structures to {}'.format(f.name))
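
Note: the PyTorch examples in this list (#1, #5, #7, #8, #9, #11) call a count_model_params helper that is defined elsewhere in each project and returns a single total; Example #12 uses a two-value variant (see the sketch after that example). A minimal sketch of the single-value form, assuming the argument is a torch.nn.Module and that only trainable parameters are counted:

def count_model_params(model):
    # Sketch only, not any project's actual implementation:
    # total element count over all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)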
Example #2
    def _build_train(self):
        tf.logging.info("-" * 80)
        tf.logging.info("Build train graph")
        logits = self._model(self.x_train,
                             is_training=True,
                             reuse=tf.AUTO_REUSE)
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)
        self.loss = tf.reduce_mean(log_probs)

        if self.use_aux_heads:
            log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.aux_logits, labels=self.y_train)
            self.aux_loss = tf.reduce_mean(log_probs)
            train_loss = self.loss + 0.4 * self.aux_loss
        else:
            train_loss = self.loss

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)
        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)

        tf_variables = [
            var for var in tf.trainable_variables()
            if (var.name.startswith(self.name) and "aux_head" not in var.name)
        ]
        self.num_vars = count_model_params(tf_variables)
        tf.logging.info("Model has {0} params".format(self.num_vars))

        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
                train_loss,
                tf_variables,
                self.global_step,
                self.num_train_steps,
                clip_mode=self.clip_mode,
                grad_bound=self.grad_bound,
                l2_reg=self.l2_reg,
                lr_init=self.lr_init,
                lr_dec_start=self.lr_dec_start,
                lr_dec_every=self.lr_dec_every,
                lr_dec_rate=self.lr_dec_rate,
                lr_cosine=self.lr_cosine,
                lr_max=self.lr_max,
                lr_min=self.lr_min,
                lr_T_0=self.lr_T_0,
                lr_T_mul=self.lr_T_mul,
                num_train_batches=self.num_train_batches,
                optim_algo=self.optim_algo,
                sync_replicas=self.sync_replicas,
                num_aggregate=self.num_aggregate,
                num_replicas=self.num_replicas)
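
The TensorFlow examples (#2, #3, #4, #6, #10) instead pass count_model_params a list of tf.Variable objects. A plausible sketch of that variant, assuming TF1-style variables with fully static shapes and numpy imported as np:

import numpy as np

def count_model_params(tf_variables):
    # Sketch only: multiply out each variable's static shape and sum.
    num_vars = 0
    for var in tf_variables:
        num_vars += np.prod([dim.value for dim in var.get_shape()])
    return num_vars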
Example #3
    def _build_train(self):
        print("-" * 80)
        print("Build train graph")
        logits = self._model(self.x_train, is_training=True)
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)  ## self.x_train [32,3,32,32] //  batch size is 32!

        self.loss = tf.reduce_mean(log_probs)  ## loss function for training

        if self.use_aux_heads:
            log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.aux_logits, labels=self.y_train)
            self.aux_loss = tf.reduce_mean(log_probs)
            train_loss = self.loss + 0.4 * self.aux_loss
        else:
            train_loss = self.loss

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)

        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)  # we should divide self.train_acc by batch_size 32

        tf_variables = [
            var for var in tf.trainable_variables() if (
                    var.name.startswith(self.name) and "aux_head" not in var.name)]
        self.num_vars = count_model_params(tf_variables)
        print("Model has {0} params".format(self.num_vars))

        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            train_loss,
            tf_variables,
            self.global_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            lr_cosine=self.lr_cosine,
            lr_max=self.lr_max,
            lr_min=self.lr_min,
            lr_T_0=self.lr_T_0,
            lr_T_mul=self.lr_T_mul,
            num_train_batches=self.num_train_batches,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)
Example #4
    def _build_train(self):
        print("Build train graph")
        logits = self._model(self.x_train, True)
        log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.y_train)
        self.loss = tf.reduce_mean(log_probs)

        self.train_preds = tf.argmax(logits, axis=1)
        self.train_preds = tf.to_int32(self.train_preds)
        self.train_acc = tf.equal(self.train_preds, self.y_train)
        self.train_acc = tf.to_int32(self.train_acc)
        self.train_acc = tf.reduce_sum(self.train_acc)

        tf_variables = [
            var for var in tf.trainable_variables()
            if var.name.startswith(self.name)
        ]
        self.num_vars = count_model_params(tf_variables)
        print("-" * 80)
        for var in tf_variables:
            print(var)

        self.global_step = tf.Variable(0,
                                       dtype=tf.int32,
                                       trainable=False,
                                       name="global_step")
        self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
            self.loss,
            tf_variables,
            self.global_step,
            clip_mode=self.clip_mode,
            grad_bound=self.grad_bound,
            l2_reg=self.l2_reg,
            lr_init=self.lr_init,
            lr_dec_start=self.lr_dec_start,
            lr_dec_every=self.lr_dec_every,
            lr_dec_rate=self.lr_dec_rate,
            optim_algo=self.optim_algo,
            sync_replicas=self.sync_replicas,
            num_aggregate=self.num_aggregate,
            num_replicas=self.num_replicas)
Example #5
def evaluate_pred(config):

    # define directories
    model_name = config.model

    test_data_root = config.data_root
    if config.deep_pred > 1:
        test_dir = config.test_dir + '/' + config.experiment_name + '/deep-pred{}/'.format(
            config.deep_pred) + model_name
    else:
        test_dir = config.test_dir + '/' + config.experiment_name + '/pred/' + model_name
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
    sample_dir = test_dir + '/samples'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    nframes_in = config.nframes_in
    nframes_pred = config.nframes_pred * config.deep_pred
    nframes = nframes_in + nframes_pred
    img_size = int(config.resl)
    nworkers = 4

    # load model
    if config.model == 'FutureGAN':
        ckpt = torch.load(config.model_path)
        # model structure
        G = ckpt['G_structure']
        # load model parameters
        G.load_state_dict(ckpt['state_dict'])
        G.eval()
        G = G.module.model
        if use_cuda:
            G = G.cuda()
        print(' ... loading FutureGAN`s FutureGenerator from checkpoint: {}'.
              format(config.model_path))

    # load test dataset
    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size),
                          interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])
    if config.model == 'FutureGAN' or config.model == 'CopyLast':
        dataset_gt = VideoFolder(video_root=test_data_root,
                                 video_ext=config.ext,
                                 nframes=nframes,
                                 loader=video_loader,
                                 transform=transform)
        dataloader_gt = DataLoader(
            dataset=dataset_gt,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_gt),
            num_workers=nworkers)
    else:
        dataset_gt = VideoFolder(video_root=test_data_root + '/in_gt',
                                 video_ext=config.ext,
                                 nframes=nframes,
                                 loader=video_loader,
                                 transform=transform)
        dataset_pred = VideoFolder(video_root=test_data_root + '/in_pred',
                                   video_ext=config.ext,
                                   nframes=nframes,
                                   loader=video_loader,
                                   transform=transform)
        dataloader_pred = DataLoader(
            dataset=dataset_pred,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_pred),
            num_workers=nworkers)
        dataloader_gt = DataLoader(
            dataset=dataset_gt,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_gt),
            num_workers=nworkers)
        data_iter_pred = iter(dataloader_pred)
    test_len = len(dataset_gt)
    data_iter_gt = iter(dataloader_gt)

    # save model structure to file
    if config.model == 'FutureGAN':
        # count model parameters
        nparams_g = count_model_params(G)
        with open(
                test_dir +
                '/model_structure_{}x{}.txt'.format(img_size, img_size),
                'w') as f:
            print('--------------------------------------------------', file=f)
            print('Sequences in test dataset: ', len(dataset_gt), file=f)
            print('Number of model parameters: ', file=f)
            print(nparams_g, file=f)
            print('--------------------------------------------------', file=f)
            print('Model structure: ', file=f)
            print(G, file=f)
            print('--------------------------------------------------', file=f)
            print(
                ' ... FutureGAN`s FutureGenerator has been loaded successfully from checkpoint ... '
            )
            print(' ... saving model structure to {}'.format(f.name))

    # save test configuration
    with open(test_dir + '/eval_config.txt', 'w') as f:
        print('------------- test configuration -------------', file=f)
        for l, m in vars(config).items():
            print(('{}: {}').format(l, m), file=f)
        print(' ... loading test configuration ... ')
        print(' ... saving test configuration to {}'.format(f.name))

    # define tensors
    if config.model == 'FutureGAN':
        print(' ... testing FutureGAN ...')
        if config.deep_pred > 1:
            print(
                ' ... recursively predicting {}x{} future frames from {} input frames ...'
                .format(config.deep_pred, config.nframes_pred, nframes_in))
        else:
            print(' ... predicting {} future frames from {} input frames ...'.
                  format(nframes_pred, nframes_in))
    z = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_in, img_size,
                          img_size))
    z_in = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_in, img_size,
                          img_size))
    x_pred = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_pred, img_size,
                          img_size))
    x = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes, img_size,
                          img_size))
    x_eval = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_pred, img_size,
                          img_size))

    # define tensors for evaluation
    if config.metrics is not None:
        print(' ... evaluating {} ...'.format(model_name))
        if 'ms_ssim' in config.metrics and img_size < 32:
            raise ValueError(
                'For calculating `ms_ssim`, your dataset must consist of images of size at least 32x32!'
            )

        metrics_values = {}
        for metric_name in config.metrics:
            metrics_values['{}_frames'.format(metric_name)] = torch.zeros_like(
                torch.FloatTensor(test_len, nframes_pred))
            metrics_values['{}_avg'.format(metric_name)] = torch.zeros_like(
                torch.FloatTensor(test_len, 1))
            print(' ... calculating {} ...'.format(metric_name))

    # test loop
    if config.metrics is not None:
        metrics_i_video = {}
        for metric_name in config.metrics:
            metrics_i_video['{}_i_video'.format(metric_name)] = 0

    i_save_video = 1
    i_save_gif = 1

    for step in tqdm(range(len(data_iter_gt))):

        # input frames
        x.data = next(data_iter_gt)
        x_eval.data = x.data[:, :, nframes_in:, :, :]
        z.data = x.data[:, :, :nframes_in, :, :]

        if use_cuda:
            x = x.cuda()
            x_eval = x_eval.cuda()
            z = z.cuda()
            x_pred = x_pred.cuda()

        # predict video frames
        # !!! TODO !!! for deep_pred > 1: correctly implemented only if nframes_in == nframes_pred
        if config.model == 'FutureGAN':
            z_in.data = z.data
            for i_deep_pred in range(0, config.deep_pred):
                x_pred[:z_in.size(0), :, i_deep_pred *
                       config.nframes_pred:(i_deep_pred *
                                            config.nframes_pred) +
                       config.nframes_pred, :, :] = G(z_in).detach()
                z_in.data = x_pred.data[:, :,
                                        i_deep_pred * config.nframes_pred:
                                        (i_deep_pred * config.nframes_pred) +
                                        config.nframes_pred, :, :]

        elif config.model == 'CopyLast':
            for i_baseline_frame in range(x_pred.size(2)):
                x_pred.data[:x.size(0), :,
                            i_baseline_frame, :, :] = x.data[:, :, nframes_in -
                                                             1, :, :]

        else:
            x_pred.data = next(data_iter_pred)[:x.size(0), :,
                                               nframes_in:, :, :]

        # calculate eval statistics
        if config.metrics is not None:
            for metric_name in config.metrics:
                calculate_metric = getattr(eval_metrics,
                                           'calculate_{}'.format(metric_name))

                for i_batch in range(x.size(0)):
                    for i_frame in range(nframes_pred):
                        metrics_values['{}_frames'.format(metric_name)][
                            metrics_i_video['{}_i_video'.format(metric_name)],
                            i_frame] = calculate_metric(
                                x_pred[i_batch, :, i_frame, :, :],
                                x_eval[i_batch, :, i_frame, :, :])
                        metrics_values['{}_avg'.format(metric_name)][
                            metrics_i_video['{}_i_video'.format(
                                metric_name)]] = torch.mean(
                                    metrics_values['{}_frames'.format(
                                        metric_name)][metrics_i_video[
                                            '{}_i_video'.format(metric_name)]])
                    metrics_i_video['{}_i_video'.format(
                        metric_name
                    )] = metrics_i_video['{}_i_video'.format(metric_name)] + 1

        # save frames
        if config.save_frames_every != 0 and config.model == 'FutureGAN':
            if step % config.save_frames_every == 0 or step == 0:
                for i_save_batch in range(x.size(0)):
                    if not os.path.exists(
                            sample_dir +
                            '/in_gt/video{:04d}'.format(i_save_video)):
                        os.makedirs(sample_dir +
                                    '/in_gt/video{:04d}'.format(i_save_video))
                    if not os.path.exists(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_video)):
                        os.makedirs(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_video))
                    for i_save_z in range(z.size(2)):
                        save_image_grid(
                            z.data[i_save_batch, :,
                                   i_save_z, :, :].unsqueeze(0), sample_dir +
                            '/in_gt/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_z + 1,
                                    img_size, img_size), img_size, 1)
                        save_image_grid(
                            z.data[i_save_batch, :,
                                   i_save_z, :, :].unsqueeze(0), sample_dir +
                            '/in_pred/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_z + 1,
                                    img_size, img_size), img_size, 1)
                    for i_save_x_pred in range(x_pred.size(2)):
                        save_image_grid(
                            x_eval.data[i_save_batch, :,
                                        i_save_x_pred, :, :].unsqueeze(0),
                            sample_dir +
                            '/in_gt/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_x_pred +
                                    1 + nframes_in, img_size, img_size),
                            img_size, 1)
                        save_image_grid(
                            x_pred.data[i_save_batch, :,
                                        i_save_x_pred, :, :].unsqueeze(0),
                            sample_dir +
                            '/in_pred/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_x_pred +
                                    1 + nframes_in, img_size, img_size),
                            img_size, 1)
                    i_save_video = i_save_video + 1

        # save gifs
        if config.save_gif_every != 0:
            if step % config.save_gif_every == 0 or step == 0:
                for i_save_batch in range(x.size(0)):
                    if not os.path.exists(
                            sample_dir +
                            '/in_gt/video{:04d}'.format(i_save_gif)):
                        os.makedirs(sample_dir +
                                    '/in_gt/video{:04d}'.format(i_save_gif))
                    if not os.path.exists(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_gif)):
                        os.makedirs(sample_dir +
                                    '/in_pred/video{:04d}'.format(i_save_gif))
                    frames = []
                    for i_save_z in range(z.size(2)):
                        frames.append(
                            get_image_grid(
                                z.data[i_save_batch, :,
                                       i_save_z, :, :].unsqueeze(0), img_size,
                                1, config.in_border, config.npx_border))
                    for i_save_x_pred in range(x_pred.size(2)):
                        frames.append(
                            get_image_grid(
                                x_eval.data[i_save_batch, :,
                                            i_save_x_pred, :, :].unsqueeze(0),
                                img_size, 1, config.out_border,
                                config.npx_border))
                    imageio.mimsave(
                        sample_dir +
                        '/in_gt/video{:04d}/video{:04d}_R{}x{}.gif'.format(
                            i_save_gif, i_save_gif, img_size, img_size),
                        frames)
                    frames = []
                    for i_save_z in range(z.size(2)):
                        frames.append(
                            get_image_grid(
                                z.data[i_save_batch, :,
                                       i_save_z, :, :].unsqueeze(0), img_size,
                                1, config.in_border, config.npx_border))
                    for i_save_x_pred in range(x_pred.size(2)):
                        frames.append(
                            get_image_grid(
                                x_pred.data[i_save_batch, :,
                                            i_save_x_pred, :, :].unsqueeze(0),
                                img_size, 1, config.out_border,
                                config.npx_border))
                    imageio.mimsave(
                        sample_dir +
                        '/in_pred/video{:04d}/video{:04d}_R{}x{}.gif'.format(
                            i_save_gif, i_save_gif, img_size, img_size),
                        frames)
                    i_save_gif = i_save_gif + 1

    if config.save_frames_every != 0 and config.model == 'FutureGAN':
        print(' ... saving video frames to dir: {}'.format(sample_dir))
        if config.save_gif_every != 0:
            print(' ... saving gifs to dir: {}'.format(sample_dir))

    # calculate and save mean eval statistics
    if config.metrics is not None:
        metrics_mean_values = {}
        for metric_name in config.metrics:
            metrics_mean_values['{}_frames'.format(metric_name)] = torch.mean(
                metrics_values['{}_frames'.format(metric_name)], 0)
            metrics_mean_values['{}_avg'.format(metric_name)] = torch.mean(
                metrics_values['{}_avg'.format(metric_name)], 0)
            torch.save(
                metrics_mean_values['{}_frames'.format(metric_name)],
                os.path.join(test_dir, '{}_frames.pt'.format(metric_name)))
            torch.save(metrics_mean_values['{}_avg'.format(metric_name)],
                       os.path.join(test_dir, '{}_avg.pt'.format(metric_name)))

        print(' ... saving evaluation statistics to dir: {}'.format(test_dir))
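
Example #5 resolves its metric functions dynamically with getattr(eval_metrics, 'calculate_{}'.format(metric_name)), so every entry in config.metrics must correspond to a calculate_<name> function that takes one predicted frame and one ground-truth frame and returns a scalar. As an illustration only (the body below is an assumption, not the project's eval_metrics code), a mean-squared-error entry could look like:

import torch

def calculate_mse(pred_frame, gt_frame):
    # Hypothetical eval_metrics entry: mean squared error between a
    # predicted frame and its ground truth (both C x H x W tensors).
    return torch.mean((pred_frame - gt_frame) ** 2)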
Example #6
    def _model(self, images, is_training, reuse=False):
        """Compute the logits given the images."""
        if self.fixed_arc is None:
            is_training = True

        with tf.variable_scope(self.name, reuse=reuse):
            with tf.variable_scope("stem_conv"):
                w = create_weight("w",
                                  [3, 3, self.channel, self.out_filters * 3])
                x = tf.nn.conv2d(images,
                                 w, [1, 1, 1, 1],
                                 "SAME",
                                 data_format=self.data_format)
                x = batch_norm(x, is_training, data_format=self.data_format)
            if self.data_format == "NHWC":
                split_axis = 3
            elif self.data_format == "NCHW":
                split_axis = 1
            else:
                raise ValueError("Unknown data_format '{0}'".format(
                    self.data_format))
            layers = [x, x]

            # building layers in the micro space
            out_filters = self.out_filters
            for layer_id in range(self.num_layers + 2):
                with tf.variable_scope("layer_{0}".format(layer_id)):
                    if layer_id not in self.pool_layers:
                        if self.fixed_arc is None:
                            x = self._enas_layer(layer_id, layers,
                                                 self.normal_arc, out_filters)

                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.normal_arc,
                                out_filters,
                                1,
                                is_training,
                                normal_or_reduction_cell="normal")
                    else:
                        out_filters *= 2
                        if self.fixed_arc is None:
                            x = self._factorized_reduction(
                                x, out_filters, 2, is_training)
                            layers = [layers[-1], x]
                            x = self._enas_layer(layer_id, layers,
                                                 self.reduce_arc, out_filters)
                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.reduce_arc,
                                out_filters,
                                2,
                                is_training,
                                normal_or_reduction_cell="reduction")

                    print("Layer {0:>2d}: {1}".format(layer_id, x))
                    layers = [layers[-1], x]

                # auxiliary heads
                self.num_aux_vars = 0
                if (self.use_aux_heads and layer_id in self.aux_head_indices
                        and is_training):
                    print("Using aux_head at layer {0}".format(layer_id))
                    with tf.variable_scope("aux_head"):
                        aux_logits = tf.nn.relu(x)
                        if (aux_logits.get_shape()[2].value - 3) % 5 == 0:
                            aux_logits = tf.layers.average_pooling2d(
                                aux_logits, [5, 5], [3, 3],
                                "VALID",
                                data_format=self.actual_data_format)
                        else:
                            aux_logits = tf.layers.average_pooling2d(
                                aux_logits, [5, 5], [3, 3],
                                "SAME",
                                data_format=self.actual_data_format)

                        with tf.variable_scope("proj"):
                            inp_c = self._get_C(aux_logits)
                            w = create_weight("w", [1, 1, inp_c, 128])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("avg_pool"):
                            inp_c = self._get_C(aux_logits)
                            hw = self._get_HW(aux_logits)
                            w = create_weight("w", [hw, hw, inp_c, 768])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("fc"):
                            aux_logits = global_avg_pool(
                                aux_logits, data_format=self.data_format)
                            inp_c = aux_logits.get_shape()[1].value
                            w = create_weight("w", [inp_c, 10])
                            aux_logits = tf.matmul(aux_logits, w)
                            self.aux_logits = aux_logits

                    aux_head_variables = [
                        var for var in tf.trainable_variables()
                        if (var.name.startswith(self.name)
                            and "aux_head" in var.name)
                    ]
                    self.num_aux_vars = count_model_params(aux_head_variables)
                    print("Aux head uses {0} params".format(self.num_aux_vars))

            x = tf.nn.relu(x)
            x = global_avg_pool(x, data_format=self.data_format)
            if is_training and self.keep_prob is not None and self.keep_prob < 1.0:
                x = tf.nn.dropout(x, self.keep_prob)
            with tf.variable_scope("fc"):
                inp_c = x.get_shape()[1].value
                w = create_weight("w", [inp_c, 10])
                x = tf.matmul(x, w)
        return x
Example #7
def train(config, model_dir, writer):
    """
    Function train and evaluate a part segmentation model for
    the Shapenet dataset. The training parameters are specified
    in the config file (for more details see config/config.py).

    :param config: Dictionary with configuration paramters
    :param model_dir: Checkpoint save directory
    :param writer: Tensorboard SummaryWritter object
    """
    phases = ['train', 'test']
    # phases = ['test', 'train']
    datasets, dataloaders, num_classes = ds.get_s3dis_dataloaders(
        root_dir=config['root_dir'],
        phases=phases,
        batch_size=config['batch_size'],
        category=config['category'],
        augment=config['augment'])

    # add number of classes to config
    config['num_classes'] = num_classes

    # we now set GPU training parameters
    # if the given index is not available then we use index 0
    # also when using multi gpu we should specify index 0
    if config['gpu_index'] + 1 > torch.cuda.device_count(
    ) or config['multi_gpu']:
        config['gpu_index'] = 0
    logging.info('Using GPU cuda:{}, script PID {}'.format(
        config['gpu_index'], os.getpid()))
    if config['multi_gpu']:
        logging.info('Training on multi-GPU mode with {} devices'.format(
            torch.cuda.device_count()))
    device = torch.device('cuda:{}'.format(config['gpu_index']))

    # we load the model defined in the config file
    # todo: now the code is IO bound. No matter which network we use, it is similar speed.
    model = res.sfc_resnet_8(in_channels=config['in_channels'],
                             num_classes=config['num_classes'],
                             kernel_size=config['kernel_size'],
                             channels=config['channels'],
                             use_tnet=config['use_tnet'],
                             n_points=config['n_points']).to(device)
    logging.info('the number of params is {: .2f} M'.format(
        utl.count_model_params(model) / (1e6)))
    # if use multi_gpu then convert the model to DataParallel
    if config['multi_gpu']:
        model = nn.DataParallel(model)

    # create optimizer, loss function, and lr scheduler
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['lr'],
                                 weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['lr_decay'],
        patience=config['lr_patience'],
        verbose=True)  # verbose. recommended to use.

    logging.info('Config {}'.format(config))
    logging.info(
        'TB logs and checkpoint will be saved in {}'.format(model_dir))

    utl.dump_config_details_to_tensorboard(writer, config)

    # create metric trackers: we track loss, class accuracy, and overall accuracy
    trackers = {
        x: {
            'loss': metrics.LossMean(),
            'cm':
            metrics.ConfusionMatrix(num_classes=int(config['num_classes']))
        }
        for x in phases
    }

    # create initial best state object
    best_state = {
        'config': config,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler else None,
        'train_loss': float('inf'),
        'test_loss': float('inf'),
        'train_mIoU': 0.0,
        'test_mIoU': 0.0,
        'convergence_epoch': 0,
        'num_epochs_since_best_acc': 0
    }

    # now we train!
    for epoch in range(config['max_epochs']):
        for phase in phases:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            # reset metrics
            trackers[phase]['loss'].reset()
            trackers[phase]['cm'].reset()

            # use tqdm to show progress and print messages
            # this is for loading our new data format
            for step_number, batchdata in enumerate(
                    tqdm(dataloaders[phase],
                         desc='[{}/{}] {} '.format(epoch + 1,
                                                   config['max_epochs'],
                                                   phase))):
                data = torch.cat((batchdata.pos, batchdata.x),
                                 dim=2).transpose(1, 2).to(device,
                                                           dtype=torch.float)
                label = batchdata.y.to(device, dtype=torch.long)
                # should we release the memory?
                # todo: add data augmentation

                # compute gradients on train only
                with torch.set_grad_enabled(phase == 'train'):
                    out = model(data)
                    loss = criterion(out, label)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # now we update metrics
                trackers[phase]['loss'].update(average_loss=loss,
                                               batch_size=data.size(0))
                trackers[phase]['cm'].update(y_true=label, y_logits=out)

            # compare with my metrics
            epoch_loss = trackers[phase]['loss'].result()
            epoch_iou = trackers[phase]['cm'].result(metric='iou').mean()

            # we update our learning rate scheduler if loss does not improve
            if phase == 'train' and scheduler:
                scheduler.step(epoch_loss)
                writer.add_scalar('params/lr', optimizer.param_groups[0]['lr'],
                                  epoch + 1)

            # log current results and dump in Tensorboard
            logging.info(
                '[{}/{}] {} Loss: {:.2e}. mIOU {:.4f} \t best testing mIOU {:.4f}'
                .format(epoch + 1, config['max_epochs'], phase, epoch_loss,
                        epoch_iou, best_state['test_mIoU']))

            writer.add_scalar('loss/epoch_{}'.format(phase), epoch_loss,
                              epoch + 1)
            writer.add_scalar('mIoU/epoch_{}'.format(phase), epoch_iou,
                              epoch + 1)

        # after each epoch we update best state values as needed
        # first we save our state when we get better test accuracy
        test_iou = trackers['test']['cm'].result(metric='iou').mean()
        if best_state['test_mIoU'] > test_iou:
            best_state['num_epochs_since_best_acc'] += 1
        else:
            logging.info(
                'Got a new best model with iou {:.4f}'.format(test_iou))
            best_state = {
                'config': config,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict() if scheduler else None,
                'train_loss': trackers['train']['loss'].result(),
                'test_loss': trackers['test']['loss'].result(),
                'train_mIoU':
                trackers['train']['cm'].result(metric='iou').mean(),
                'test_mIoU': test_iou,
                'convergence_epoch': epoch + 1,
                'num_epochs_since_best_acc': 0
            }

            file_name = os.path.join(model_dir, 'best_state.pth')
            torch.save(best_state, file_name)
            logging.info('saved checkpoint in {}'.format(file_name))

        # we check for early stopping when we have trained a min number of epochs
        if epoch >= config['min_epochs'] and best_state[
                'num_epochs_since_best_acc'] >= config['early_stopping']:
            logging.info('Accuracy did not improve for {} iterations!'.format(
                config['early_stopping']))
            logging.info('[Early stopping]')
            break

    utl.dump_best_model_metrics_to_tensorboard(writer, phases, best_state)

    logging.info('************************** DONE **************************')
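
Example #7 reads mean IoU out of a project-specific ConfusionMatrix tracker via result(metric='iou'). The per-class IoU behind that call is standard: for class i, the intersection is the diagonal entry cm[i, i] and the union is the row sum plus the column sum minus that diagonal. A minimal sketch of the computation, assuming a square integer confusion matrix stored as a torch tensor:

import torch

def iou_from_confusion_matrix(cm):
    # cm[i, j] = number of points with true class i predicted as class j.
    intersection = torch.diag(cm).float()
    union = cm.sum(dim=0).float() + cm.sum(dim=1).float() - intersection
    return intersection / union.clamp(min=1)  # guard against empty classes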
Example #8
    def __init__(self, config):

        self.config = config

        # log directory
        if self.config.experiment_name=='':
            self.experiment_name = current_time
        else:
            self.experiment_name = self.config.experiment_name

        self.log_dir = config.log_dir+'/'+self.experiment_name
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # save config settings to file
        with open(self.log_dir+'/train_config.txt','w') as f:
            print('------------- training configuration -------------', file=f)
            for k, v in vars(config).items():
                print(('{}: {}').format(k, v), file=f)
            print(' ... loading training configuration ... ')
            print(' ... saving training configuration to {}'.format(f.name))

        self.train_data_root = self.config.data_root

        # training samples
        self.train_sample_dir = self.log_dir+'/samples_train'

        # checkpoints
        self.ckpt_dir = self.log_dir+'/ckpts'
        print('#######################')

        print("checkpoint dir:",self.ckpt_dir)
        print('#######################')

        # tensorboard
        if self.config.tb_logging:
            self.tb_dir = self.log_dir+'/tensorboard'

        self.use_cuda = use_cuda
        self.nz = config.nz
        self.nc = config.nc
        self.optimizer = config.optimizer
        self.batch_size_table = config.batch_size_table
        self.lr = config.lr
        self.d_eps_penalty = config.d_eps_penalty
        self.acgan = config.acgan
        self.max_resl = int(np.log2(config.max_resl))
        self.nframes_in = config.nframes_in
        self.nframes_pred = config.nframes_pred
        self.nframes = self.nframes_in+self.nframes_pred
        self.ext = config.ext
        self.nworkers = 4
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.complete = 0.0
        self.x_add_noise = config.x_add_noise
        self.fadein = {'G':None, 'D':None}
        self.init_resl = 2
        self.init_img_size = int(pow(2, self.init_resl))

        # initialize model G as FutureGenerator from model.py
        self.G = model.FutureGenerator(config)

        # initialize model D as Discriminator from model.py
        self.D = model.Discriminator(config)

        # define losses
        if self.config.loss=='lsgan':
            self.criterion = torch.nn.MSELoss()

        elif self.config.loss=='gan':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()

        elif self.config.loss=='wgan_gp':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Loss is undefined! Please set one of the following: `gan`, `lsgan` or `wgan_gp`')

        # check --use_ckpt
        # if --use_ckpt==False: build initial model
        # if --use_ckpt==True: load and build model from specified checkpoints
        if self.config.use_ckpt==False:

            print(' ... creating initial models ... ')

            # set initial model parameters
            self.resl = self.init_resl
            self.start_resl = self.init_resl
            self.globalIter = 0
            self.nsamples = 0
            self.stack = 0
            self.epoch = 0
            self.iter_start = 0
            self.phase = 'init'
            self.flag_flush = False

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()

            # count model parameters
            nparams_g = count_model_params(self.G)
            nparams_d = count_model_params(self.D)

            # save initial model structure to file
            with open(self.log_dir+'/initial_model_structure_{}x{}.txt'.format(self.init_img_size, self.init_img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), ', Batch size: ', self.batch_size, file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('Number of Generator`s model parameters: ', file=f)
                print(nparams_g, file=f)
                print('Number of Discriminator`s model parameters: ', file=f)
                print(nparams_d, file=f)
                print('--------------------------------------------------', file=f)
                print('Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... initial models have been built successfully ... ')
                print(' ... saving initial model structures to {}'.format(f.name))

            # ship everything to cuda and parallelize for ngpu>1
            if self.use_cuda:
                self.criterion = self.criterion.cuda()
                torch.cuda.manual_seed(config.random_seed)
                if config.ngpu==1:
                    self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                    self.D = torch.nn.DataParallel(self.D).cuda(device=0)
                else:
                    gpus = list(range(config.ngpu))
                    self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                    self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        else:

            # re-ship everything to cuda
            if self.use_cuda:
                self.G = self.G.cuda()
                self.D = self.D.cuda()

            # load checkpoint
            print(' ... loading models from checkpoints ... {} and {}'.format(self.config.ckpt_path[0], self.config.ckpt_path[1]))
            self.ckpt_g = torch.load(self.config.ckpt_path[0])
            self.ckpt_d = torch.load(self.config.ckpt_path[1])

            # get model parameters
            self.resl = self.ckpt_g['resl']
            self.start_resl = int(self.ckpt_g['resl'])
            self.iter_start = self.ckpt_g['iter']+1
            self.globalIter = int(self.ckpt_g['globalIter'])
            self.stack = int(self.ckpt_g['stack'])
            self.nsamples = int(self.ckpt_g['nsamples'])
            self.epoch = int(self.ckpt_g['epoch'])
            self.fadein['G'] = self.ckpt_g['fadein']
            self.fadein['D'] = self.ckpt_d['fadein']
            self.phase = self.ckpt_d['phase']
            self.complete  = self.ckpt_d['complete']
            self.flag_flush = self.ckpt_d['flag_flush']
            img_size = int(pow(2, floor(self.resl)))

            # get model structure
            self.G = self.ckpt_g['G_structure']
            self.D = self.ckpt_d['D_structure']

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()
            self.schedule_resl()
            self.nsamples = int(self.ckpt_g['nsamples'])

            # save loaded model structure to file
            with open(self.log_dir+'/resumed_model_structure_{}x{}.txt'.format(img_size, img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... models have been loaded successfully from checkpoints ... ')
                print(' ... saving resumed model structures to {}'.format(f.name))

            # load model state_dict
            self.G.load_state_dict(self.ckpt_g['state_dict'])
            self.D.load_state_dict(self.ckpt_d['state_dict'])

            # load optimizer state dict
            lr = self.lr
            for i in range(1,int(floor(self.resl))-1):
                self.lr = lr*(self.config.lr_decay**i)
#            self.opt_g.load_state_dict(self.ckpt_g['optimizer'])
#            self.opt_d.load_state_dict(self.ckpt_d['optimizer'])
#            for param_group in self.opt_g.param_groups:
#                self.lr = param_group['lr']

        # tensorboard logging
        self.tb_logging = self.config.tb_logging
        if self.tb_logging==True:
            if not os.path.exists(self.tb_dir):
                os.makedirs(self.tb_dir)
            self.logger = Logger(self.tb_dir)
Example #9
def train(params):
    # ============
    # Preparations
    # ============
    gc.collect()
    ray.init(log_to_driver=False, local_mode=False,
             num_gpus=1)  # or, ray.init()
    if not params.use_pretrain:
        # algorithm ingredients instantiation
        seed = params.seed
        actor = gen_actor(params.env_name, params.policy_params.hidden_dim)
        critic = gen_critic(params.env_name, params.policy_params.hidden_dim)
        optimizer = torch.optim.Adam(list(actor.parameters()) +
                                     list(critic.parameters()),
                                     lr=params.policy_params.learning_rate)
        rollout_time, update_time = AverageMeter(), AverageMeter()
        iteration_pretrain = 0
        # set random seed (for reproducing experiment)
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
    else:
        # build models
        actor = gen_actor(params.env_name,
                          params.policy_params.hidden_dim).cuda()
        critic = gen_critic(params.env_name,
                            params.policy_params.hidden_dim).cuda()
        optimizer = torch.optim.Adam(list(actor.parameters()) +
                                     list(critic.parameters()),
                                     lr=0.0001)
        # load models
        print("\n\nLoading training checkpoint...")
        print("------------------------------")
        load_path = os.path.join('./save/model', params.pretrain_file)
        checkpoint = torch.load(load_path)
        seed = checkpoint['seed']
        actor.load_state_dict(checkpoint['actor_state_dict'])
        actor.train()
        critic.load_state_dict(checkpoint['critic_state_dict'])
        critic.train()
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        [rollout_time, update_time] = checkpoint['time_recorder']
        iteration_pretrain = checkpoint['iteration']
        # >> set random seed (for reproducing experiment)
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        print("Loading finished!")
        print("------------------------------\n\n")
    rolloutmem = RolloutMemory(
        params.policy_params.envs_num * params.policy_params.horizon,
        params.env_name)
    envs = [
        ParallelEnv.remote(params.env_name, i)
        for i in range(params.policy_params.envs_num)
    ]
    for i in range(len(envs)):
        envs[i].seed.remote(seed=seed + i)
    tb = SummaryWriter()
    # ============
    # Training
    # ============
    # >> training loop
    print("----------------------------------")
    print("Training model with {} parameters...".format(
        count_model_params(actor) + count_model_params(critic)))
    print("----------------------------------")
    time_start = time.time()
    for iteration in range(int(params.iter_num - iteration_pretrain)):
        # collect rollouts from current policy
        rolloutmem.reset()
        iter_start_time = time.time()
        mean_iter_reward = rollout(rolloutmem, envs, actor, critic, params)
        # optimize by gradient descent
        update_start_time = time.time()
        loss, policy_loss, critic_loss, entropy_loss, advantage, ratio, surr1, surr2, epochs_len = \
            None, None, None, None, None, None, None, None, None,
        for epoch in range(params.policy_params.epochs_num):
            loss, policy_loss, critic_loss, entropy_loss, advantage, ratio, surr1, surr2, epochs_len = \
                optimize_step(optimizer, rolloutmem, actor, critic, params, iteration)
        iter_end_time = time.time()
        tb = logger_scalar(tb, iteration + iteration_pretrain, loss,
                           policy_loss, critic_loss, entropy_loss, advantage,
                           ratio, surr1, surr2, epochs_len, mean_iter_reward,
                           time_start)
        # tb = logger_histogram(tb, iteration + iteration_pretrain, actor, critic)
        rollout_time.update(update_start_time - iter_start_time)
        update_time.update(iter_end_time - update_start_time)
        tb.add_scalar('rollout_time', rollout_time.val,
                      iteration + iteration_pretrain)
        tb.add_scalar('update_time', update_time.val,
                      iteration + iteration_pretrain)
        print(
            'it {}: avgR: {:.3f} avgL: {:.3f} | rollout_time: {:.3f}sec update_time: {:.3f}sec'
            .format(iteration + iteration_pretrain, mean_iter_reward,
                    epochs_len, rollout_time.val, update_time.val))
        # save rollout video
        if (iteration + 1) % int(params.plotting_iters) == 0 \
                and iteration > 0 \
                and params.log_video \
                and params.env_name not in envnames_classiccontrol:
            log_policy_rollout(
                params, actor, params.env_name,
                'iter-{}'.format(iteration + iteration_pretrain))
        # save model
        if (iteration + 1) % int(
                params.checkpoint_iter
        ) == 0 and iteration > 0 and params.save_checkpoint:
            save_model(params.prefix, iteration, iteration_pretrain, seed,
                       actor, critic, optimizer, rollout_time, update_time)
    # save rollout videos
    if params.log_video:
        save_model(params.prefix, params.iter_num, iteration_pretrain, seed,
                   actor, critic, optimizer, rollout_time, update_time)
        if params.env_name not in envnames_classiccontrol:
            for i in range(3):
                log_policy_rollout(params, actor, params.env_name,
                                   'final-{}'.format(i))
Example #10
    def _model(self, images, is_training, reuse=tf.AUTO_REUSE):

        if self.fixed_arc is None:
            is_training = True

        with tf.variable_scope(self.name, reuse=reuse):
            # the first two inputs
            with tf.variable_scope("stem_conv"):
                w = create_weight("w", [3, 3, 3, self.out_filters * 3])
                x = tf.nn.conv2d(images,
                                 w, [1, 1, 1, 1],
                                 "SAME",
                                 data_format=self.data_format)
                x = batch_norm(x, is_training, data_format=self.data_format)

            layers = [x, x]

            out_filters = self.out_filters
            for layer_id in range(self.num_layers + 2):
                with tf.variable_scope("layer_{0}".format(layer_id)):
                    if layer_id not in self.pool_layers:
                        if self.fixed_arc is None:
                            x = self._enas_layer(layer_id, layers,
                                                 self.normal_arc, out_filters)
                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.normal_arc,
                                out_filters,
                                1,
                                is_training,
                                normal_or_reduction_cell="normal")
                    else:
                        out_filters *= 2
                        if self.fixed_arc is None:
                            x = self._factorized_reduction(
                                x, out_filters, 2, is_training)
                            layers = [layers[-1], x]
                            x = self._enas_layer(layer_id, layers,
                                                 self.reduce_arc, out_filters)
                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.reduce_arc,
                                out_filters,
                                2,
                                is_training,
                                normal_or_reduction_cell="reduction")
                    tf.logging.info("Layer {0:>2d}: {1}".format(layer_id, x))
                    layers = [layers[-1], x]

                # auxiliary heads
                self.num_aux_vars = 0
                if (self.use_aux_heads and layer_id in self.aux_head_indices
                        and is_training):
                    tf.logging.info(
                        "Using aux_head at layer {0}".format(layer_id))
                    with tf.variable_scope("aux_head"):
                        aux_logits = tf.nn.relu(x)
                        aux_logits = tf.layers.average_pooling2d(
                            aux_logits, [5, 5], [3, 3],
                            "VALID",
                            data_format=self.actual_data_format)
                        with tf.variable_scope("proj"):
                            inp_c = self._get_C(aux_logits)
                            w = create_weight("w", [1, 1, inp_c, 128])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("avg_pool"):
                            inp_c = self._get_C(aux_logits)
                            hw = self._get_HW(aux_logits)
                            w = create_weight("w", [hw, hw, inp_c, 768])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("fc"):
                            aux_logits = global_avg_pool(
                                aux_logits, data_format=self.data_format)
                            inp_c = aux_logits.get_shape()[1].value
                            w = create_weight("w", [inp_c, 10])
                            aux_logits = tf.matmul(aux_logits, w)
                            self.aux_logits = aux_logits

                    aux_head_variables = [
                        var for var in tf.trainable_variables()
                        if (var.name.startswith(self.name)
                            and "aux_head" in var.name)
                    ]
                    self.num_aux_vars = count_model_params(aux_head_variables)
                    tf.logging.info("Aux head uses {0} params".format(
                        self.num_aux_vars))

            x = tf.nn.relu(x)
            x = global_avg_pool(x, data_format=self.data_format)
            if is_training and self.keep_prob is not None and self.keep_prob < 1.0:
                x = tf.nn.dropout(x, self.keep_prob)
            with tf.variable_scope("fc"):
                inp_c = self._get_C(x)
                w = create_weight("w", [inp_c, 10])
                x = tf.matmul(x, w)
        return x
Example #11
    def num_params(self):
        return count_model_params(self.actor) + count_model_params(self.critic)
Example #12
print(args)
train_loader, val_loader, test_loader = data_generator(root, batch_size)

model = MNIST_Classifier(input_channels,
                         n_classes,
                         args.hidden_size,
                         args.n_layers,
                         device,
                         tt=args.tt,
                         gru=args.gru,
                         n_cores=args.ncores,
                         tt_rank=args.ttrank,
                         naive_tt=args.naive_tt,
                         log_grads=args.log_grads,
                         extra_core=args.extra_core)
n_trainable, n_nontrainable = count_model_params(model)
print(
    "Model instantiated. Trainable params: {}, Non-trainable params: {}. Total: {}"
    .format(n_trainable, n_nontrainable, n_trainable + n_nontrainable))

# Setup activation and gradient logging
if args.log_grads:
    from tensorized_rnn.rnn_utils import ActivGradLogger as AGL

permute = torch.Tensor(np.random.permutation(784).astype(np.float64)).long()
if args.cuda:
    model.cuda()
    permute = permute.cuda()

# Set learning rate, optimizer, scheduler
lr = args.lr
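
Unlike the single-value helpers used elsewhere in this list, Example #12 unpacks two values from count_model_params, so this project's version evidently reports trainable and non-trainable parameters separately. A minimal sketch of such a two-value variant for a PyTorch module (an assumption, not the project's code):

def count_model_params(model):
    # Split the parameter count by whether the tensor receives gradients.
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_nontrainable = sum(p.numel() for p in model.parameters()
                         if not p.requires_grad)
    return n_trainable, n_nontrainable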