Ejemplo n.º 1
0
  def __init__(self,
               data_dir=None,
               is_training=True,
               learning_rate=0.0002,
               beta1=0.9,
               reconstr_weight=0.85,
               smooth_weight=0.05,
               ssim_weight=0.15,
               icp_weight=0.0,
               batch_size=4,
               img_height=128,
               img_width=416,
               seq_length=3,
               legacy_mode=False):
    """Stores hyperparameters and builds the training or inference graph.

    Args:
      data_dir: Directory containing the training data.
      is_training: If True, builds the training graph (with a data reader);
        otherwise builds the depth and egomotion test graphs.
      learning_rate: Adam learning rate.
      beta1: Adam beta1 momentum parameter.
      reconstr_weight: Weight of the reconstruction loss.
      smooth_weight: Weight of the smoothness loss.
      ssim_weight: Weight of the SSIM loss.
      icp_weight: Weight of the ICP loss.
      batch_size: Number of sequences per training batch.
      img_height: Input image height in pixels.
      img_width: Input image width in pixels.
      seq_length: Number of frames per input sequence.
      legacy_mode: Enables backwards-compatible behavior.
    """
    self.data_dir = data_dir
    self.is_training = is_training
    self.learning_rate = learning_rate
    self.reconstr_weight = reconstr_weight
    self.smooth_weight = smooth_weight
    self.ssim_weight = ssim_weight
    self.icp_weight = icp_weight
    self.beta1 = beta1
    self.batch_size = batch_size
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.legacy_mode = legacy_mode

    logging.info('data_dir: %s', data_dir)
    logging.info('learning_rate: %s', learning_rate)
    logging.info('beta1: %s', beta1)
    # Fix: reconstr_weight was the only hyperparameter that was assigned but
    # never logged.
    logging.info('reconstr_weight: %s', reconstr_weight)
    logging.info('smooth_weight: %s', smooth_weight)
    logging.info('ssim_weight: %s', ssim_weight)
    logging.info('icp_weight: %s', icp_weight)
    logging.info('batch_size: %s', batch_size)
    logging.info('img_height: %s', img_height)
    logging.info('img_width: %s', img_width)
    logging.info('seq_length: %s', seq_length)
    logging.info('legacy_mode: %s', legacy_mode)

    if self.is_training:
      self.reader = reader.DataReader(self.data_dir, self.batch_size,
                                      self.img_height, self.img_width,
                                      self.seq_length, NUM_SCALES)
      self.build_train_graph()
    else:
      self.build_depth_test_graph()
      self.build_egomotion_test_graph()

    # At this point, the model is ready.  Print some info on model params.
    util.count_parameters()
Ejemplo n.º 2
0
    def __init__(self,
                 data_dir=None,
                 is_training=True,
                 learning_rate=0.0002,
                 beta1=0.9,
                 reconstr_weight=0.85,
                 smooth_weight=0.05,
                 ssim_weight=0.15,
                 icp_weight=0.0,
                 batch_size=4,
                 img_height=128,
                 img_width=416,
                 seq_length=3,
                 legacy_mode=False):
        """Stores hyperparameters and builds the training or inference graph.

        Args:
            data_dir: Directory containing the training data.
            is_training: If True, builds the training graph (with a data
                reader); otherwise builds the depth and egomotion test graphs.
            learning_rate: Adam learning rate.
            beta1: Adam beta1 momentum parameter.
            reconstr_weight: Weight of the reconstruction loss.
            smooth_weight: Weight of the smoothness loss.
            ssim_weight: Weight of the SSIM loss.
            icp_weight: Weight of the ICP loss.
            batch_size: Number of sequences per training batch.
            img_height: Input image height in pixels.
            img_width: Input image width in pixels.
            seq_length: Number of frames per input sequence.
            legacy_mode: Enables backwards-compatible behavior.
        """
        self.data_dir = data_dir
        self.is_training = is_training
        self.learning_rate = learning_rate
        self.reconstr_weight = reconstr_weight
        self.smooth_weight = smooth_weight
        self.ssim_weight = ssim_weight
        self.icp_weight = icp_weight
        self.beta1 = beta1
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.seq_length = seq_length
        self.legacy_mode = legacy_mode

        logging.info('data_dir: %s', data_dir)
        logging.info('learning_rate: %s', learning_rate)
        logging.info('beta1: %s', beta1)
        # Fix: reconstr_weight was the only hyperparameter that was assigned
        # but never logged.
        logging.info('reconstr_weight: %s', reconstr_weight)
        logging.info('smooth_weight: %s', smooth_weight)
        logging.info('ssim_weight: %s', ssim_weight)
        logging.info('icp_weight: %s', icp_weight)
        logging.info('batch_size: %s', batch_size)
        logging.info('img_height: %s', img_height)
        logging.info('img_width: %s', img_width)
        logging.info('seq_length: %s', seq_length)
        logging.info('legacy_mode: %s', legacy_mode)

        if self.is_training:
            self.reader = reader.DataReader(self.data_dir, self.batch_size,
                                            self.img_height, self.img_width,
                                            self.seq_length, NUM_SCALES)
            self.build_train_graph()
        else:
            self.build_depth_test_graph()
            self.build_egomotion_test_graph()

        # At this point, the model is ready.  Print some info on model params.
        util.count_parameters()
Ejemplo n.º 3
0
def print_params():
    """Print the total VAE parameter count and a per-scope breakdown.

    The difference between the total and the sum of the known scopes is
    reported as 'auxiliary nodes'.
    """
    print('---------------')
    total_count = util.count_parameters()
    print(f'number of parameters in vae: {total_count}')

    params = {
        scope: util.count_parameters(scope)
        for scope in ['encoder', 'decoder', 'upsampler']
    }

    # Iterate over values() directly instead of discarding keys from items().
    # Scopes whose count is unknown (None) are excluded from the sum.
    params['auxiliary nodes'] = total_count - sum(
        count for count in params.values() if count is not None)

    for scope, count in params.items():
        print('---------------')
        print(f'number of parameters in {scope}: {count}')

    print('---------------')
def main():
    """Evaluate a saved FCN-VGG16 model on a single-image training set.

    Loads a serialized model, runs it over the ScanNet2d loader, and prints
    per-class IoU statistics and pixel accuracy.
    """
    path = "C:\\Users\\ji\\Documents\\FCN-VGG16\\models\\FCNs-BCEWithLogits_batch1_epoch10000_RMSprop_scheduler-step50-gamma0.5_lr0.001_momentum0.5_w_decay1e-05_input_size484"
    model = torch.load(path)
    model.eval()
    print("num para")
    print(util.count_parameters(model))
    root_dir = ".\\"
    train_file = os.path.join(root_dir, "train_one.csv")
    train_data = ScanNet2d(csv_file=train_file,
                           phase='train',
                           MeanRGB=np.array([0.0, 0.0, 0.0]))

    train_loader = DataLoader(train_data,
                              batch_size=1,
                              shuffle=True,
                              num_workers=1)
    n_class = 40
    use_gpu = True
    total_ious = []
    pixel_accs = []
    # Renamed the loop variable from `iter` to avoid shadowing the builtin.
    for batch_idx, batch in enumerate(train_loader):
        if use_gpu:
            inputs = Variable(batch['X'].cuda())
        else:
            inputs = Variable(batch['X'])

        output = model(inputs)
        output = output.data.cpu().numpy()

        N, _, h, w = output.shape
        # Argmax over n_class + 1 channels (the extra channel is presumably
        # background -- TODO confirm against the training setup).
        pred = output.transpose(0, 2, 3, 1).reshape(
            -1, n_class + 1).argmax(axis=1).reshape(N, h, w)

        target = batch['l'].cpu().numpy().reshape(N, h, w)
        for p, t in zip(pred, target):
            ious, count, ds = iou(p, t, n_class)
            total_ious.append(ious)
            pixel_accs.append(pixel_acc(p, t))
    total_ious = np.array(total_ious).T
    ious = np.nanmean(total_ious, axis=1)
    pixel_accs = np.array(pixel_accs).mean()

    print(total_ious)
    # NOTE(review): `count` and `ds` leak out of the inner loop above; this
    # raises NameError if the loader yields no batches.
    print(count)
    print(ds)
    print(np.nansum(ious) / count)
    print("pix_acc: {}, meanIoU: {}".format(pixel_accs, np.nanmean(ious)))
Ejemplo n.º 5
0
def all_count_params(model_name, num_layers, kernel_sizes, channel_size):
    """Build a grid of parameter counts over (num_layers, kernel_sizes).

    Returns a list of rows, one per layer count (ascending); each row holds
    one entry per kernel size (ascending).  Configurations whose kernel size
    exceeds the per-depth limit are recorded as NaN and never instantiated.
    """
    # Largest allowed kernel for each layer count; depths above 5 cap at 15,
    # depths below 3 have no cap.
    explicit_limits = {5: 20, 4: 25, 3: 50}
    grid = []
    for depth in sorted(num_layers):
        row = []
        grid.append(row)
        limit = 15 if depth > 5 else explicit_limits.get(depth)
        for kernel in sorted(kernel_sizes):
            if limit is not None and kernel > limit:
                row.append(float('nan'))
                continue
            net = load_model(model_name, channel_size, kernel, depth)
            row.append(util.count_parameters(net))
    return grid
Ejemplo n.º 6
0
def main():
    """Train a Graph WaveNet model, validate every epoch, then evaluate the
    best checkpoint on the test set and dump loss statistics to a report file.
    """
    # set seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # load data
    device = torch.device(args.device)
    sensor_ids, sensor_id_to_ind, adj_mx = util.load_adj(
        args.adjdata, args.adjtype)
    # suffix = '_filtered_we'  # _filtered_we, _filtered_ew
    eR_seq_size = 24  # 24
    error_size = 6
    dataloader = util.load_dataset(args.data,
                                   args.batch_size,
                                   args.batch_size,
                                   args.batch_size,
                                   eRec=args.eRec,
                                   eR_seq_size=eR_seq_size,
                                   suffix=args.suffix)
    scaler = dataloader['scaler']

    # When retraining, take the scaler from the original training set so the
    # normalization stays consistent with the pretrained weights.
    if args.retrain:
        dl_train = util.load_dataset(args.data,
                                     args.batch_size,
                                     args.batch_size,
                                     args.batch_size,
                                     eRec=args.eRec,
                                     eR_seq_size=eR_seq_size,
                                     suffix=args.suffix_train)
        scaler = dl_train['scaler']

    blocks = int(dataloader[f'x_train{args.suffix}'].shape[-3] /
                 3)  # Every block reduce the input sequence size by 3.
    print(f'blocks = {blocks}')

    supports = [torch.tensor(i).to(device) for i in adj_mx]

    print(args)

    if args.randomadj:
        adjinit = None
    else:
        adjinit = supports[0]

    if args.aptonly:
        supports = None

    engine = trainer(scaler,
                     args.in_dim,
                     args.seq_length,
                     args.num_nodes,
                     args.nhid,
                     args.dropout,
                     args.learning_rate,
                     args.weight_decay,
                     device,
                     supports,
                     args.gcn_bool,
                     args.addaptadj,
                     adjinit,
                     blocks,
                     eRec=args.eRec,
                     retrain=args.retrain,
                     checkpoint=args.checkpoint,
                     error_size=error_size)

    if args.retrain:
        dataloader['val_loader'] = dataloader['train_loader']

    print("start training...", flush=True)
    his_loss = []
    val_time = []
    train_time = []
    for i in range(1, args.epochs + 1):
        #if i % 10 == 0:
        #lr = max(0.000002,args.learning_rate * (0.1 ** (i // 10)))
        #for g in engine.optimizer.param_groups:
        #g['lr'] = lr
        train_loss = []
        train_mape = []
        train_rmse = []
        t1 = time.time()
        dataloader['train_loader'].shuffle()
        for iter, (x,
                   y) in enumerate(dataloader['train_loader'].get_iterator()):
            trainx = torch.Tensor(x).to(device)
            trainy = torch.Tensor(y).to(device)
            # eRec batches carry an extra leading error-sequence dimension.
            if args.eRec:
                trainx = trainx.transpose(0, 1)
                trainy = trainy.transpose(0, 1)
            trainx = trainx.transpose(-3, -1)
            trainy = trainy.transpose(-3, -1)
            if args.eRec:
                metrics = engine.train(trainx, trainy[:, :, 0, :, :])
            else:
                metrics = engine.train(trainx, trainy[:, 0, :, :])
            train_loss.append(metrics[0])
            train_mape.append(metrics[1])
            train_rmse.append(metrics[2])
            if iter % args.print_every == 0:
                log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
                print(log.format(iter, train_loss[-1], train_mape[-1],
                                 train_rmse[-1]),
                      flush=True)
        t2 = time.time()
        train_time.append(t2 - t1)
        #validation
        valid_loss = []
        valid_mape = []
        valid_rmse = []

        s1 = time.time()
        for iter, (x, y) in enumerate(dataloader['val_loader'].get_iterator()):
            testx = torch.Tensor(x).to(device)
            testy = torch.Tensor(y).to(device)
            if args.eRec:
                testx = testx.transpose(0, 1)
                testy = testy.transpose(0, 1)
            testx = testx.transpose(-3, -1)
            testy = testy.transpose(-3, -1)
            if args.eRec:
                metrics = engine.eval(testx, testy[:, :, 0, :, :])
            else:
                metrics = engine.eval(testx, testy[:, 0, :, :])
            valid_loss.append(metrics[0])
            valid_mape.append(metrics[1])
            valid_rmse.append(metrics[2])
        s2 = time.time()
        log = 'Epoch: {:03d}, Inference Time: {:.4f} secs'
        print(log.format(i, (s2 - s1)))
        val_time.append(s2 - s1)
        mtrain_loss = np.mean(train_loss)
        mtrain_mape = np.mean(train_mape)
        mtrain_rmse = np.mean(train_rmse)

        mvalid_loss = np.mean(valid_loss)
        mvalid_mape = np.mean(valid_mape)
        mvalid_rmse = np.mean(valid_rmse)
        his_loss.append(mvalid_loss)

        log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}, Valid Loss: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch'
        print(log.format(i, mtrain_loss, mtrain_mape, mtrain_rmse, mvalid_loss,
                         mvalid_mape, mvalid_rmse, (t2 - t1)),
              flush=True)
        # One checkpoint per epoch, tagged with its validation loss.
        torch.save(
            engine.model.state_dict(), args.save + "_epoch_" + str(i) + "_" +
            str(round(mvalid_loss, 2)) + ".pth")
    print("Average Training Time: {:.4f} secs/epoch".format(
        np.mean(train_time)))
    print("Average Inference Time: {:.4f} secs".format(np.mean(val_time)))

    #testing
    # Pick the epoch with the lowest validation loss (model ID = bestid + 1).
    bestid = np.argmin(his_loss)
    engine.model.load_state_dict(
        torch.load(args.save + "_epoch_" + str(bestid + 1) + "_" +
                   str(round(his_loss[bestid], 2)) + ".pth"))

    print(f'best_id = {bestid+1}')

    outputs = []
    realy = torch.Tensor(dataloader[f'y_test{args.suffix}']).to(device)
    if args.eRec:
        realy = realy.transpose(0, 1)
        realy = realy.transpose(-3, -1)[-1, :, 0, :, :]
    else:
        realy = realy.transpose(-3, -1)[:, 0, :, :]
    criterion = nn.MSELoss(reduction='none')  # L2 Norm
    criterion2 = nn.L1Loss(reduction='none')
    loss_mse_list = []
    loss_mae_list = []

    for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()):
        testx = torch.Tensor(x).to(device)
        testy = torch.Tensor(y).to(device)
        if args.eRec:
            testx = testx.transpose(0, 1)
            testy = testy.transpose(0, 1)
        testx = testx.transpose(-3, -1)
        testy = testy.transpose(-3, -1)
        with torch.no_grad():
            if args.eRec:
                preds = engine.model(testx, testy[:, :, 0:1, :, :],
                                     scaler).transpose(1, 3)
            else:
                preds = engine.model(testx).transpose(1, 3)

        # Unscaled per-element losses; reduction happens in the report below.
        if args.eRec:
            loss_mse = criterion(
                scaler.inverse_transform(torch.squeeze(preds.transpose(-3,
                                                                       -1))),
                torch.squeeze(testy[-1, :, 0:1, :, :].transpose(-3, -1)))
            loss_mae = criterion2(
                scaler.inverse_transform(torch.squeeze(preds.transpose(-3,
                                                                       -1))),
                torch.squeeze(testy[-1, :, 0:1, :, :].transpose(-3, -1)))
        else:
            loss_mse = criterion(
                scaler.inverse_transform(torch.squeeze(preds.transpose(-3,
                                                                       -1))),
                torch.squeeze(testy[:, 0:1, :, :].transpose(-3, -1)))
            loss_mae = criterion2(
                scaler.inverse_transform(torch.squeeze(preds.transpose(-3,
                                                                       -1))),
                torch.squeeze(testy[:, 0:1, :, :].transpose(-3, -1)))

        loss_mse_list.append(loss_mse)
        loss_mae_list.append(loss_mae)

        outputs.append(preds.squeeze())

    # Drop the last (possibly partial) batch before concatenating.
    loss_mse_list.pop(-1)
    loss_mae_list.pop(-1)
    loss_mse = torch.cat(loss_mse_list, 0)
    loss_mae = torch.cat(loss_mae_list, 0)
    loss_mse = loss_mse.cpu()
    loss_mae = loss_mae.cpu()
    print(f'loss_mae: {loss_mae.shape}')
    # Fix: this line previously printed loss_mae.shape under the MSE label.
    print(f'loss_mse: {loss_mse.shape}')

    res_folder = 'results/'
    original_stdout = sys.stdout
    with open(res_folder + f'loss_evaluation.txt', 'w') as filehandle:
        sys.stdout = filehandle  # Change the standard output to the file we created.
        count_parameters(engine.model)
        # loss_mae.shape --> (batch_size, seq_size, n_detect)
        print(' 1. ***********')
        print_loss('MSE', loss_mse)
        print(' 2. ***********')
        print_loss('MAE', loss_mae)
        print(' 3. ***********')
        print_loss_sensor('MAE', loss_mae)
        print(' 5. ***********')
        print_loss_seq('MAE', loss_mae)
        print(' 6. ***********')
        print_loss_sensor_seq('MAE', loss_mae)

        sys.stdout = original_stdout  # Reset the standard output to its original value

    with open(res_folder + f'loss_evaluation.txt', 'r') as filehandle:
        print(filehandle.read())

    yhat = torch.cat(outputs, dim=0)
    yhat = yhat[:realy.size(0), ...]

    print("Training finished")

    # Per-horizon metrics on the best checkpoint.
    amae = []
    amape = []
    armse = []
    for i in range(args.seq_length):
        pred = scaler.inverse_transform(yhat[:, :, i])
        real = realy[:, :, i]
        metrics = util.metric(pred, real)
        log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
        print(log.format(i + 1, metrics[0], metrics[1], metrics[2]))
        amae.append(metrics[0])
        amape.append(metrics[1])
        armse.append(metrics[2])

    log = 'On average over {:.4f} horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
    print(
        log.format(args.seq_length, np.mean(amae), np.mean(amape),
                   np.mean(armse)))
    torch.save(
        engine.model.state_dict(), args.save + "_exp" + str(args.expid) +
        "_best_" + str(round(np.min(his_loss), 2)) + ".pth")
Ejemplo n.º 7
0
    def __init__(
        self,
        d_in=28 * 28,
        d_out=28 * 28,
        m=200,
        n=6,
        k=25,
        k_winner_cells=1,
        gamma=0.5,
        eps=0.5,
        activation_fn="tanh",
        embed_dim=0,
        vocab_size=0,
        dropout_p=0.0,
        decode_from_full_memory=False,
        debug_log_names=None,
        mask_shifted_pi=False,
        do_inhibition=True,
        boost_strat="rsm_inhibition",
        pred_gain=1.0,
        x_b_norm=False,
        boost_strength=1.0,
        mult_integration=False,
        boost_strength_factor=1.0,
        forget_mu=0.0,
        weight_sparsity=None,
        debug=False,
        visual_debug=False,
        use_bias=True,
        fpartition=None,
        balance_part_winners=False,
        **kwargs,
    ):
        """
        Attempted replication of the Recurrent Sparse Memory architecture
        suggested by [Rawlinson et al 2019](https://arxiv.org/abs/1905.11589).

        The many keyword parameters allow experimentation with a wide array of
        adjustments, both minor and major. Classes of models tested include:

        * "Adjusted" model with k-winners and column boosting, 2 cell winners,
            no inhibition
        * "Flattened" model with 1 cell per column, 1000 cols, 25 winners
            and multiplicative integration of FF & recurrent input
        * "Flat Partitioned" model with 120 winners, and cells partitioned
            into three functional types: ff only, recurrent only, and
            optionally a region that integrates both.

        :param d_in: Dimension of input
        :param m: Number of groups/columns
        :param n: Cells per group/column
        :param k: # of groups/columns to win in topk() (sparsity)
        :param k_winner_cells: # of winning cells per column
        :param gamma: Inhibition decay rate (0-1)
        :param eps: Integrated encoding decay rate (0-1)
        """
        super().__init__()

        # Core geometry: m columns of n cells each, k winning columns.
        self.m = m
        self.n = n
        self.k = int(k)
        self.k_winner_cells = k_winner_cells
        self.total_cells = m * n
        self.flattened = self.total_cells == self.m

        # Dimensions and dynamics.
        self.d_in = d_in
        self.d_out = d_out
        self.gamma = gamma
        self.eps = eps
        self.dropout_p = float(dropout_p)
        self.forget_mu = float(forget_mu)

        # Experimental tweaks.
        self.activation_fn = activation_fn
        self.decode_from_full_memory = decode_from_full_memory
        self.boost_strat = boost_strat
        self.pred_gain = pred_gain
        self.x_b_norm = x_b_norm
        self.mask_shifted_pi = mask_shifted_pi
        self.do_inhibition = do_inhibition
        self.boost_strength = boost_strength
        self.boost_strength_factor = boost_strength_factor
        self.mult_integration = mult_integration
        self.fpartition = fpartition
        if isinstance(self.fpartition, float):
            # A single float is shorthand for [ff_pct, rec_pct]; a list is
            # interpreted as given.
            self.fpartition = [self.fpartition, 1.0 - self.fpartition]
        self.balance_part_winners = balance_part_winners
        self.weight_sparsity = weight_sparsity

        # Debugging switches.
        self.debug = debug
        self.visual_debug = visual_debug
        self.debug_log_names = debug_log_names

        self.dropout = nn.Dropout(p=self.dropout_p)

        self._build_input_layers_and_kwinners(use_bias=use_bias)

        # Decode either from the full cell memory or from the column summary.
        decoder_in = self.total_cells if self.decode_from_full_memory else m
        self.linear_d = nn.Linear(
            decoder_in, d_out, bias=use_bias
        )  # Decoding through bottleneck

        print("Created model with %d trainable params" % count_parameters(self))
Ejemplo n.º 8
0
load_function = util_classes.get_init_func(model_name)
params_list = np.zeros(len(kernel_sizes))
names = []
ks = []
nl = []
# Instantiate each (kernel size, layer count) configuration and record its
# parameter count.
for i, (k_size, n_layers) in enumerate(zip(kernel_sizes, num_layers)):
    args = {
        "input_shape": input_shape,
        "kernel_size": k_size,
        "channel_size": 32,
        "num_layers": n_layers,
        "n_classes": 256
    }
    model = load_function(args)
    num_params = util.count_parameters(model)
    print(f"{model_name}_k{k_size:3}_l{n_layers:2}={num_params}")

    params_list[i] = num_params
    names.append(f"{model_name}_k{k_size}_c32_l{n_layers}")
    ks.append(k_size)
    nl.append(n_layers)

threshold = 100000

kernel_map = {}
layer_map = {}
min_max = {}
# Parameter-count window of +/- threshold around each configuration.
for i, param_count in enumerate(params_list):
    x_min = param_count - threshold
    x_max = param_count + threshold
Ejemplo n.º 9
0
  def __init__(self,
               data_dir=None,
               file_extension='png',
               is_training=True,
               learning_rate=0.0002,
               beta1=0.9,
               reconstr_weight=0.85,
               smooth_weight=0.05,
               ssim_weight=0.15,
               icp_weight=0.0,
               batch_size=4,
               img_height=128,
               img_width=416,
               seq_length=3,
               architecture=nets.RESNET,
               imagenet_norm=True,
               weight_reg=0.05,
               exhaustive_mode=False,
               random_scale_crop=False,
               flipping_mode=reader.FLIP_RANDOM,
               random_color=True,
               depth_upsampling=True,
               depth_normalization=True,
               compute_minimum_loss=True,
               use_skip=True,
               joint_encoder=True,
               build_sum=True,
               shuffle=True,
               input_file='train',
               handle_motion=False,
               equal_weighting=False,
               size_constraint_weight=0.0,
               train_global_scale_var=True):
    """Stores hyperparameters, logs them, and builds the requested graph.

    Builds the training graph (with a data reader) when is_training is True,
    otherwise builds the depth/egomotion (and optionally object-motion) test
    graphs.  All keyword arguments are stored verbatim as attributes.
    """
    self.data_dir = data_dir
    self.file_extension = file_extension
    self.is_training = is_training
    self.learning_rate = learning_rate
    self.reconstr_weight = reconstr_weight
    self.smooth_weight = smooth_weight
    self.ssim_weight = ssim_weight
    self.icp_weight = icp_weight
    self.beta1 = beta1
    self.batch_size = batch_size
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.architecture = architecture
    self.imagenet_norm = imagenet_norm
    self.weight_reg = weight_reg
    self.exhaustive_mode = exhaustive_mode
    self.random_scale_crop = random_scale_crop
    self.flipping_mode = flipping_mode
    self.random_color = random_color
    self.depth_upsampling = depth_upsampling
    self.depth_normalization = depth_normalization
    self.compute_minimum_loss = compute_minimum_loss
    self.use_skip = use_skip
    self.joint_encoder = joint_encoder
    self.build_sum = build_sum
    self.shuffle = shuffle
    self.input_file = input_file
    self.handle_motion = handle_motion
    self.equal_weighting = equal_weighting
    self.size_constraint_weight = size_constraint_weight
    self.train_global_scale_var = train_global_scale_var

    logging.info('data_dir: %s', data_dir)
    logging.info('file_extension: %s', file_extension)
    logging.info('is_training: %s', is_training)
    logging.info('learning_rate: %s', learning_rate)
    logging.info('reconstr_weight: %s', reconstr_weight)
    logging.info('smooth_weight: %s', smooth_weight)
    logging.info('ssim_weight: %s', ssim_weight)
    logging.info('icp_weight: %s', icp_weight)
    logging.info('size_constraint_weight: %s', size_constraint_weight)
    logging.info('beta1: %s', beta1)
    logging.info('batch_size: %s', batch_size)
    logging.info('img_height: %s', img_height)
    logging.info('img_width: %s', img_width)
    logging.info('seq_length: %s', seq_length)
    logging.info('architecture: %s', architecture)
    logging.info('imagenet_norm: %s', imagenet_norm)
    logging.info('weight_reg: %s', weight_reg)
    logging.info('exhaustive_mode: %s', exhaustive_mode)
    logging.info('random_scale_crop: %s', random_scale_crop)
    logging.info('flipping_mode: %s', flipping_mode)
    logging.info('random_color: %s', random_color)
    logging.info('depth_upsampling: %s', depth_upsampling)
    logging.info('depth_normalization: %s', depth_normalization)
    logging.info('compute_minimum_loss: %s', compute_minimum_loss)
    logging.info('use_skip: %s', use_skip)
    logging.info('joint_encoder: %s', joint_encoder)
    logging.info('build_sum: %s', build_sum)
    logging.info('shuffle: %s', shuffle)
    logging.info('input_file: %s', input_file)
    logging.info('handle_motion: %s', handle_motion)
    logging.info('equal_weighting: %s', equal_weighting)
    logging.info('train_global_scale_var: %s', train_global_scale_var)

    if self.size_constraint_weight > 0 or not is_training:
      # Fix: np.infty was removed in NumPy 2.0; np.inf is the canonical name.
      self.global_scale_var = tf.Variable(
          0.1, name='global_scale_var',
          trainable=self.is_training and train_global_scale_var,
          dtype=tf.float32,
          constraint=lambda x: tf.clip_by_value(x, 0, np.inf))

    if self.is_training:
      self.reader = reader.DataReader(self.data_dir, self.batch_size,
                                      self.img_height, self.img_width,
                                      self.seq_length, NUM_SCALES,
                                      self.file_extension,
                                      self.random_scale_crop,
                                      self.flipping_mode,
                                      self.random_color,
                                      self.imagenet_norm,
                                      self.shuffle,
                                      self.input_file)
      self.build_train_graph()
    else:
      self.build_depth_test_graph()
      self.build_egomotion_test_graph()
      if self.handle_motion:
        self.build_objectmotion_test_graph()

    # At this point, the model is ready. Print some info on model params.
    util.count_parameters()
# Define dataloaders
# Cityscapes test split: images are rescaled then converted to tensors.
test_set = CSDataset('dataset/Cityscapes/test.csv',
                     transform=transforms.Compose(
                         [Rescale(img_size), CSToTensor()]))
test_loader = DataLoader(test_set,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=8,
                         drop_last=False)

# model and loss
# G = PSPNet(backend='resnet18', psp_size=512, pretrained=False).to(device)
G = PSPNetShareEarlyLayer(backend='resnet18shareearlylayer',
                          psp_size=512,
                          pretrained=False).to(device)
print(count_parameters(G))

# Restore generator weights; evaluation is pointless without a checkpoint.
if os.path.isfile(checkpoint_path):
    state = torch.load(checkpoint_path)
    G.load_state_dict(state['state_dict_G'])
else:
    print('No checkpoint found')
    exit()

G.eval()  # Set model to evaluate mode
# Iterate over data.
# NOTE(review): the loop body below appears truncated in this excerpt; it
# only unpacks the batch onto the device.
for i, temp_batch in enumerate(test_loader):

    temp_rgb = temp_batch['rgb'].float().to(device)
    temp_foregd = temp_batch['foregd'].long().to(device)
    temp_partial_bkgd = temp_batch['partial_bkgd'].long().squeeze().to(device)
Ejemplo n.º 11
0
                                              use_noise=args['--use_noise'],
                                              noise_sigma=float(
                                                  args['--noise_sigma']))

    # Discriminator over whole videos; dim_categorical enables category
    # prediction (presumably for the InfoGAN variant -- TODO confirm).
    video_discriminator = build_discriminator(args['--video_discriminator'],
                                              dim_categorical=dim_z_category,
                                              n_channels=n_channels,
                                              use_noise=args['--use_noise'],
                                              noise_sigma=float(
                                                  args['--noise_sigma']))

    # Move all networks to the GPU when one is available.
    if torch.cuda.is_available():
        generator.cuda()
        image_discriminator.cuda()
        video_discriminator.cuda()

    print('The number of parameters for Video disciminator is : {0}'.format(
        count_parameters(video_discriminator)))
    summary((3, 64, 64, 16), video_discriminator)

    # CLI options come from a docopt-style args dict.
    trainer = Trainer(image_loader,
                      video_loader,
                      int(args['--print_every']),
                      int(args['--batches']),
                      args['<log_folder>'],
                      use_cuda=torch.cuda.is_available(),
                      use_infogan=args['--use_infogan'],
                      use_categories=args['--use_categories'])

    trainer.train(generator, image_discriminator, video_discriminator)
Ejemplo n.º 12
0
def main(config, progress):
    """Train (or, in test mode, evaluate) a BERT-based response-selection model.

    Builds one/two/three BERT encoders (context / response / persona,
    depending on `shared` and `num_personas`), optionally mixes PEC data into
    a casual-conversation training set, and runs the train/eval loops.

    Args:
        config: dict of experiment hyper-parameters (data paths, model name,
            batch size, lr, fp16 settings, ...). `seed` and `config_id` are
            popped from it before it is stored into the returned metrics.
        progress: fraction (0-1) of the overall experiment sweep completed;
            printed for bookkeeping only.

    Returns:
        metrics dict with the (reduced) config, best validation accuracy,
        the epoch it occurred at, and per-epoch recall/MRR lists.
        In test mode the process exits via sys.exit() after evaluation.
    """
    # save config
    with open("./log/configs.json", "a") as f:
        json.dump(config, f)
        f.write("\n")
    cprint("*"*80)
    cprint("Experiment progress: {0:.2f}%".format(progress*100))
    cprint("*"*80)
    metrics = {}

    # data hyper-params
    train_path = config["train_path"]
    valid_path = config["valid_path"]
    test_path = config["test_path"]
    # dataset name is taken from the 4th path component, e.g. ./data/<x>/<dataset>/...
    dataset = train_path.split("/")[3]
    test_mode = bool(config["test_mode"])
    load_model_path = config["load_model_path"]
    save_model_path = config["save_model_path"]
    num_candidates = config["num_candidates"]
    num_personas = config["num_personas"]
    persona_path = config["persona_path"]
    max_sent_len = config["max_sent_len"]
    max_seq_len = config["max_seq_len"]
    PEC_ratio = config["PEC_ratio"]
    train_ratio = config["train_ratio"]
    # PEC mixing and train-set subsampling are mutually exclusive experiments.
    if PEC_ratio != 0 and train_ratio != 1:
        raise ValueError("PEC_ratio or train_ratio not qualified!")
    
    # model hyper-params
    config_id = config["config_id"]
    model = config["model"]
    shared = bool(config["shared"])
    apply_interaction = bool(config["apply_interaction"])
    matching_method = config["matching_method"]
    aggregation_method = config["aggregation_method"]
    output_hidden_states = False
    
    # training hyper-params
    batch_size = config["batch_size"]
    epochs = config["epochs"]
    warmup_steps = config["warmup_steps"]
    gradient_accumulation_steps = config["gradient_accumulation_steps"]
    lr = config["lr"]
    weight_decay = 0
    seed = config["seed"]
    device = torch.device(config["device"])
    fp16 = bool(config["fp16"])
    fp16_opt_level = config["fp16_opt_level"]

    # set seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if test_mode and load_model_path == "":
        raise ValueError("Must specify test model path when in test mode!")

    # load data
    cprint("Loading conversation data...")
    train = load_pickle(train_path)
    valid = load_pickle(valid_path)
    if test_mode:
        # in test mode the "valid" split is silently swapped for the test split
        test = load_pickle(test_path)
        valid_path = test_path
        valid = test

    cprint("sample train data: ", train[0])
    cprint("sample valid data: ", valid[0])
    
    # tokenization
    cprint("Tokenizing ...")
    tokenizer = BertTokenizer.from_pretrained(model)
    # tokenized datasets are cached next to the raw pickles ("*_tokenized.pkl")
    cached_tokenized_train_path = train_path.replace(".pkl", "_tokenized.pkl")
    cached_tokenized_valid_path = valid_path.replace(".pkl", "_tokenized.pkl")
    if os.path.exists(cached_tokenized_train_path):
        cprint("Loading tokenized dataset from ", cached_tokenized_train_path)
        train = load_pickle(cached_tokenized_train_path)
    else:
        train = tokenize_conversations(train, tokenizer, max_sent_len)
        cprint("Saving tokenized dataset to ", cached_tokenized_train_path)
        save_pickle(train, cached_tokenized_train_path)

    if os.path.exists(cached_tokenized_valid_path):
        cprint("Loading tokenized dataset from ", cached_tokenized_valid_path)
        valid = load_pickle(cached_tokenized_valid_path)
    else:
        valid = tokenize_conversations(valid, tokenizer, max_sent_len)
        cprint("Saving tokenized dataset to ", cached_tokenized_valid_path)
        save_pickle(valid, cached_tokenized_valid_path)
    
    persona = None
    if num_personas > 0:
        cprint("Tokenizing persona sentences...")
        cached_tokenized_persona_path = persona_path.replace(".pkl", "_tokenized.pkl")
        if os.path.exists(cached_tokenized_persona_path):
            cprint("Loading tokenized persona from file...")
            persona = load_pickle(cached_tokenized_persona_path)
        else:
            cprint("Loading persona data...")
            persona = load_pickle(persona_path)
            # collect every speaker id across train/valid/test so each gets a persona entry
            all_speakers = set([s for conv in load_pickle(config["train_path"]) + \
                load_pickle(config["valid_path"]) + load_pickle(config["test_path"]) for s, sent in conv])
            cprint("Tokenizing persona data...")
            persona = tokenize_personas(persona, tokenizer, all_speakers, num_personas)
            cprint("Saving tokenized persona to file...")
            save_pickle(persona, cached_tokenized_persona_path)
        cprint("Persona dataset statistics (after tokenization):", len(persona))
        cprint("Sample tokenized persona:", list(persona.values())[0])        
    
    cprint("Sample tokenized data: ")
    cprint(train[0])
    cprint(valid[0])
        
    # select subsets of training and validation data for casualconversation
    cprint(dataset)
    if dataset == "casualconversation_v3":
        cprint("reducing dataset size...")
        train = train[:150000]
        valid = valid[:20000]

    if train_ratio != 1:
        num_train_examples = int(len(train) * train_ratio)
        cprint("reducing training set size to {0}...".format(num_train_examples))
        train = train[:num_train_examples]

    if PEC_ratio != 0:
        # replace the first PEC_ratio fraction of the casual training set
        # with (already tokenized) PEC examples
        cprint("Replacing {0} of casual to PEC...".format(PEC_ratio))
        cprint(len(train))
        
        PEC_train_path = "./data/reddit_empathetic/combined_v3/train_cleaned_bert.pkl"
        PEC_persona_path = "./data/reddit_empathetic/combined_v3/persona_10.pkl"
        
        # load cached casual conversations and persona
        num_PEC_examples = int(len(train) * PEC_ratio)
        train[:num_PEC_examples] = load_pickle(PEC_train_path.replace(".pkl", "_tokenized.pkl"))[:num_PEC_examples]
        cprint(num_PEC_examples, len(train))
        
        if num_personas > 0:
            cprint("number of speakers before merging PEC and casual: ", len(persona))
            # merge persona
            PEC_persona = load_pickle(PEC_persona_path.replace(".pkl", "_tokenized.pkl"))
            for k,v in PEC_persona.items():
                if k not in persona:
                    persona[k] = v
            cprint("number of speakers after merging PEC and casual: ", len(persona))

    # create context and response
    train = create_context_and_response(train)
    valid = create_context_and_response(valid)
    cprint("Sample context and response: ")
    cprint(train[0])
    cprint(valid[0])

    # convert to token ids
    cprint("Converting conversations to ids: ")
    if not test_mode:
        train_dataset = convert_conversations_to_ids(train, persona, tokenizer, \
            max_seq_len, max_sent_len, num_personas)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, drop_last=True)
        t_total = len(train_dataloader) // gradient_accumulation_steps * epochs
        cprint(train_dataset[0])
    valid_dataset = convert_conversations_to_ids(valid, persona, tokenizer, \
        max_seq_len, max_sent_len, num_personas)
    valid_sampler = RandomSampler(valid_dataset)
    # NOTE(review): each validation batch holds `num_candidates` examples,
    # presumably forming one candidate-ranking set; a RandomSampler shuffles
    # examples across batches — confirm this grouping is intended.
    valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=num_candidates)

    # create model
    cprint("Building model...")
    model = BertModel.from_pretrained(model, output_hidden_states=output_hidden_states)
    cprint(model)
    cprint("number of parameters: ", count_parameters(model))

    if shared:
        cprint("number of encoders: 1")
        models = [model]
    else:
        if num_personas == 0:
            cprint("number of encoders: 2")
            # pickle round-trip used as a deepcopy substitute (see commented-out deepcopy)
            # models = [model, copy.deepcopy(model)]
            models = [model, pickle.loads(pickle.dumps(model))]
        else:
            cprint("number of encoders: 3")
            # models = [model, copy.deepcopy(model), copy.deepcopy(model)]
            models = [model, pickle.loads(pickle.dumps(model)), pickle.loads(pickle.dumps(model))]
    
    if test_mode:
        # test mode always evaluates a single (shared) encoder checkpoint
        cprint("Loading weights from ", load_model_path)
        model.load_state_dict(torch.load(load_model_path))
        models = [model]
    
    for i, model in enumerate(models):
        cprint("model {0} number of parameters: ".format(i), count_parameters(model))
        model.to(device)

    # optimization
    amp = None
    if fp16:
        from apex import amp
    
    # standard BERT fine-tuning setup: no weight decay on biases / LayerNorm
    no_decay = ["bias", "LayerNorm.weight"]
    optimizers = []
    schedulers = []
    for i, model in enumerate(models):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8)

        if fp16:
            # amp.initialize may wrap the model, so store the wrapped one back
            model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
            models[i] = model
        optimizers.append(optimizer)
        
        if not test_mode:
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
            schedulers.append(scheduler)

    if test_mode:
        # evaluation
        for model in models:
            model.eval()
        valid_iterator = tqdm(valid_dataloader, desc="Iteration")
        valid_loss, (valid_acc, valid_recall, valid_MRR) = evaluate_epoch(valid_iterator, models, \
            num_personas, gradient_accumulation_steps, device, dataset, 0, apply_interaction, matching_method, aggregation_method)
        cprint("test loss: {0:.4f}, test acc: {1:.4f}, test recall: {2}, test MRR: {3:.4f}"
            .format(valid_loss, valid_acc, valid_recall, valid_MRR))
        sys.exit()
    
    # training
    epoch_train_losses = []
    epoch_valid_losses = []
    epoch_valid_accs = []
    epoch_valid_recalls = []
    epoch_valid_MRRs = []
    cprint("***** Running training *****")
    cprint("Num examples =", len(train_dataset))
    cprint("Num Epochs =", epochs)
    cprint("Total optimization steps =", t_total)
    best_model_statedict = {}
    for epoch in range(epochs):
        cprint("Epoch", epoch+1)
        # training
        for model in models:
            model.train()
        train_iterator = tqdm(train_dataloader, desc="Iteration")
        train_loss, (train_acc, _, _) = train_epoch(train_iterator, models, num_personas, optimizers, \
            schedulers, gradient_accumulation_steps, device, fp16, amp, apply_interaction, matching_method, aggregation_method)
        epoch_train_losses.append(train_loss)
    
        # evaluation
        for model in models:
            model.eval()
        valid_iterator = tqdm(valid_dataloader, desc="Iteration")
        valid_loss, (valid_acc, valid_recall, valid_MRR) = evaluate_epoch(valid_iterator, models, \
            num_personas, gradient_accumulation_steps, device, dataset, epoch, apply_interaction, matching_method, aggregation_method)
        
        cprint("Config id: {7}, Epoch {0}: train loss: {1:.4f}, valid loss: {2:.4f}, train_acc: {3:.4f}, valid acc: {4:.4f}, valid recall: {5}, valid_MRR: {6:.4f}"
            .format(epoch+1, train_loss, valid_loss, train_acc, valid_acc, valid_recall, valid_MRR, config_id))
        
        epoch_valid_losses.append(valid_loss)
        epoch_valid_accs.append(valid_acc)
        epoch_valid_recalls.append(valid_recall)
        epoch_valid_MRRs.append(valid_MRR)

        if save_model_path != "":
            # snapshot the first encoder's weights (moved to CPU) whenever
            # recall@1 reaches a new best; epoch 0 always snapshots
            if epoch == 0:
                for k, v in models[0].state_dict().items():
                    best_model_statedict[k] = v.cpu()
            else:
                if epoch_valid_recalls[-1][0] == max([recall1 for recall1, _, _ in epoch_valid_recalls]):
                    for k, v in models[0].state_dict().items():
                        best_model_statedict[k] = v.cpu()


    # seed/config_id are per-run details, not part of the reported config
    config.pop("seed")
    config.pop("config_id")
    metrics["config"] = config
    # NOTE(review): the best checkpoint is chosen by recall@1 above, but the
    # reported score/epoch use accuracy — confirm this mismatch is intended.
    metrics["score"] = max(epoch_valid_accs)
    metrics["epoch"] = np.argmax(epoch_valid_accs).item()
    metrics["recall"] = epoch_valid_recalls
    metrics["MRR"] = epoch_valid_MRRs

    if save_model_path:
        cprint("Saving model to ", save_model_path)
        torch.save(best_model_statedict, save_model_path)

    return metrics
Ejemplo n.º 13
0
seed = 0
epochs = 60
lr = 1e-3

criteria = nn.MSELoss()
model = W_PDE()
#model = Model()
model = model.cuda()
#optimizer = optim.LBFGS(model.parameters(), lr=lr)
optimizer = optim.Adam(model.parameters(), lr=lr)

xval = xval.cuda()
yval = yval.cuda()

# Trainin loop
print("Optimizing %d parameters on %s" % (util.count_parameters(model), 'cuda'))
time.sleep(0.5)
for epoch in range(epochs):
    running_loss = 0
    train_p_bar = tqdm(range(len(xtrain) // 128))
    for batch_idx in train_p_bar:
        xbatch = xtrain[batch_idx*128:(batch_idx+1)*128].cuda()
        ybatch = ytrain[batch_idx*128:(batch_idx+1)*128].cuda()
        
        pde, y_hat = model.pde_structure(xbatch)
        loss = criteria(pde, torch.zeros(pde.shape).to(pde.device)) + criteria(y_hat, ybatch)
        
        #y_hat = model(xbatch)
        #loss = criteria(y_hat, ybatch)
        
        optimizer.zero_grad()
Ejemplo n.º 14
0
# Setup parameters, model, and optimizer
seed = 0
epochs = 500
lr = 1e-3

criteria = nn.MSELoss()
#model = W_PDE()
model = Model()
model = model.cuda()
#optimizer = optim.LBFGS(model.parameters(), lr=lr)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Trainin loop
print("Optimizing %d parameters on %s" %
      (util.count_parameters(model), 'cuda'))
time.sleep(0.5)
for epoch in range(epochs):
    running_loss = 0
    train_p_bar = tqdm(range(len(xtrain) // 128))
    for batch_idx in train_p_bar:
        xbatch = xtrain[batch_idx * 128:(batch_idx + 1) * 128].cuda()
        ybatch = ytrain[batch_idx * 128:(batch_idx + 1) * 128].cuda()

        #pde, y_hat = model.pde_structure(xbatch)
        #loss = criteria(pde, torch.zeros(pde.shape).to(pde.device)) + criteria(y_hat, ybatch)

        y_hat = model(xbatch)
        loss = criteria(y_hat, ybatch)

        optimizer.zero_grad()
Ejemplo n.º 15
0
                 BIDIRECTIONAL, DROPOUT, PAD_IDX)

    pretrained_embeddings = TEXT.vocab.vectors
    model.embedding.weight.data.copy_(pretrained_embeddings)
    UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
    model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
    model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)
    optimizer = optim.Adam(model.parameters())
    #criterion = nn.CrossEntropyLoss()
    criterion = nn.MSELoss(reduction='mean')

    model = model.to(device)
    criterion = criterion.to(device)

    print('The model has %s trainable parameters' %
          util.count_parameters(model))
    print(pretrained_embeddings.shape)
    print(model.embedding.weight.data)

    print('Just started:')
    valid_loss, valid_acc = util.evaluate(model, valid_iterator, criterion,
                                          labelName)
    print(f'\t Val. Loss: {valid_loss:.3e} |  Val. Acc: {valid_acc*100:.2e}%')
    for epoch in range(N_EPOCHS):
        start_time = time.time()

        train_loss, train_acc = util.train(model, train_iterator, optimizer,
                                           criterion, labelName)
        valid_loss, valid_acc = util.evaluate(model, valid_iterator, criterion,
                                              labelName)
Ejemplo n.º 16
0
    def __init__(
        self,
        d_in=28 * 28,
        d_out=28 * 28,
        d_above=None,
        m=200,
        n=6,
        k=25,
        k_winner_cells=1,
        gamma=0.5,
        eps=0.5,
        activation_fn="tanh",
        decode_activation_fn=None,
        embed_dim=0,
        vocab_size=0,
        decode_from_full_memory=False,
        debug_log_names=None,
        boost_strat="rsm_inhibition",
        x_b_norm=False,
        boost_strength=1.0,
        duty_cycle_period=1000,
        mult_integration=False,
        boost_strength_factor=1.0,
        forget_mu=0.0,
        weight_sparsity=None,
        feedback_conn=False,
        input_bias=False,
        decode_bias=True,
        lateral_conn=True,
        col_output_cells=False,
        debug=False,
        visual_debug=False,
        fpartition=None,
        balance_part_winners=False,
        trainable_decay=False,
        trainable_decay_rec=False,
        max_decay=1.0,
        mem_floor=0.0,
        additive_decay=False,
        stoch_decay=False,
        stoch_k_sd=0.0,
        rec_active_dendrites=0,
        **kwargs,
    ):
        """
        This class includes an attempted replication of the Recurrent Sparse Memory
        architecture suggested by by
        [Rawlinson et al 2019](https://arxiv.org/abs/1905.11589).

        Parameters allow experimentation with a wide array of adjustments to this model,
        both minor and major. Classes of models tested include:

        * "Adjusted" model with k-winners and column boosting, 2 cell winners,
            no inhibition
        * "Flattened" model with 1 cell per column, 1000 cols, 25 winners
            and multiplicative integration of FF & recurrent input
        * "Flat Partitioned" model with 120 winners, and cells partitioned into three
            functional types: ff only, recurrent only, and optionally a region that
            integrates both.

        :param d_in: Dimension of input
        :param m: Number of groups/columns
        :param n: Cells per group/column
        :param k: # of groups/columns to win in topk() (sparsity)
        :param k_winner_cells: # of winning cells per column
        :param gamma: Inhibition decay rate (0-1)
        :param eps: Integrated encoding decay rate (0-1)

        """
        super(RSMLayer, self).__init__()
        self.k = int(k)
        self.k_winner_cells = k_winner_cells
        self.m = m
        self.n = n
        self.gamma = gamma
        self.eps = eps
        self.d_in = d_in
        self.d_out = d_out
        self.d_above = d_above
        self.forget_mu = float(forget_mu)

        self.total_cells = m * n
        # "flattened" configuration: exactly one cell per column (n == 1)
        self.flattened = self.total_cells == self.m

        # Tweaks
        self.activation_fn = activation_fn
        self.decode_activation_fn = decode_activation_fn
        self.decode_from_full_memory = decode_from_full_memory
        self.boost_strat = boost_strat
        self.x_b_norm = x_b_norm
        self.boost_strength = boost_strength
        self.boost_strength_factor = boost_strength_factor
        self.duty_cycle_period = duty_cycle_period
        self.mult_integration = mult_integration
        self.fpartition = fpartition
        if isinstance(self.fpartition, float):
            # Handle simple single-param FF-percentage only
            # If fpartition is list, interpreted as [ff_pct, rec_pct]
            self.fpartition = [self.fpartition, 1.0 - self.fpartition]
        self.balance_part_winners = balance_part_winners
        self.weight_sparsity = weight_sparsity
        self.feedback_conn = feedback_conn
        self.input_bias = input_bias
        self.decode_bias = decode_bias
        self.lateral_conn = lateral_conn
        self.trainable_decay = trainable_decay
        self.trainable_decay_rec = trainable_decay_rec
        self.max_decay = max_decay
        self.additive_decay = additive_decay
        self.stoch_decay = stoch_decay
        self.col_output_cells = col_output_cells
        self.stoch_k_sd = stoch_k_sd
        self.rec_active_dendrites = rec_active_dendrites
        self.mem_floor = mem_floor

        self.debug = debug
        self.visual_debug = visual_debug
        self.debug_log_names = debug_log_names

        self._build_layers_and_kwinners()

        # Per-cell memory decay rates. Note both the additive and stochastic
        # variants start from the same uniform(-3, 3) init; presumably they
        # differ only in how the value is used downstream — confirm.
        if self.additive_decay:
            decay_init = torch.ones(self.total_cells,
                                    dtype=torch.float32).uniform_(-3.0, 3.0)
        elif self.stoch_decay:
            # Fixed random decay rates, test with trainable_decay = False
            decay_init = torch.ones(self.total_cells,
                                    dtype=torch.float32).uniform_(-3.0, 3.0)
        else:
            decay_init = self.eps * torch.ones(self.total_cells,
                                               dtype=torch.float32)
        self.decay = nn.Parameter(decay_init,
                                  requires_grad=self.trainable_decay)
        # NOTE(review): assigning an nn.Parameter to an attribute already
        # registers it under "decay"; this explicit re-registration looks
        # redundant — confirm before removing.
        self.register_parameter("decay", self.decay)
        self.learning_iterations = 0
        # duty cycle buffer tracks per-cell activation frequency for boosting
        self.register_buffer("duty_cycle", torch.zeros(self.total_cells))

        print("Created %s with %d trainable params" %
              (str(self), count_parameters(self)))
Ejemplo n.º 17
0
    def __init__(self,
                 data_dir=None,
                 file_extension='png',
                 is_training=True,
                 learning_rate=0.0002,
                 beta1=0.9,
                 reconstr_weight=0.85,
                 smooth_weight=0.05,
                 object_depth_weight=0.0,
                 object_depth_threshold=0.01,
                 exclude_object_mask=True,
                 stop_egomotion_gradient=True,
                 ssim_weight=0.15,
                 batch_size=4,
                 img_height=128,
                 img_width=416,
                 seq_length=3,
                 architecture=nets.RESNET,
                 imagenet_norm=True,
                 weight_reg=0.05,
                 exhaustive_mode=False,
                 random_scale_crop=False,
                 flipping_mode=reader.FLIP_RANDOM,
                 random_color=True,
                 depth_upsampling=True,
                 depth_normalization=True,
                 compute_minimum_loss=True,
                 use_skip=True,
                 use_axis_angle=False,
                 joint_encoder=True,
                 build_sum=True,
                 shuffle=True,
                 input_file='train',
                 handle_motion=False,
                 equal_weighting=False,
                 same_trans_rot_scaling=True,
                 residual_deformer=True,
                 seg_align_type='null',
                 use_rigid_residual_flow=True,
                 region_deformer_scaling=1.0):
        """Configure the model and build its train or test graph.

        Stores all hyper-parameters on the instance, logs them, then either
        constructs the training graph (with a DataReader over `data_dir`)
        or the depth + egomotion test graphs, depending on `is_training`.
        """
        self.data_dir = data_dir
        self.file_extension = file_extension
        self.is_training = is_training
        self.learning_rate = learning_rate
        self.reconstr_weight = reconstr_weight
        self.smooth_weight = smooth_weight
        self.ssim_weight = ssim_weight
        self.object_depth_weight = object_depth_weight
        self.object_depth_threshold = object_depth_threshold
        self.exclude_object_mask = exclude_object_mask
        self.beta1 = beta1
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.seq_length = seq_length
        self.architecture = architecture
        self.imagenet_norm = imagenet_norm
        self.weight_reg = weight_reg
        self.exhaustive_mode = exhaustive_mode
        self.random_scale_crop = random_scale_crop
        self.flipping_mode = flipping_mode
        self.random_color = random_color
        self.depth_upsampling = depth_upsampling
        self.depth_normalization = depth_normalization
        self.compute_minimum_loss = compute_minimum_loss
        self.use_skip = use_skip
        self.joint_encoder = joint_encoder
        self.build_sum = build_sum
        self.shuffle = shuffle
        self.input_file = input_file
        self.handle_motion = handle_motion
        self.equal_weighting = equal_weighting
        self.same_trans_rot_scaling = same_trans_rot_scaling
        self.residual_deformer = residual_deformer
        self.seg_align_type = seg_align_type
        self.use_rigid_residual_flow = use_rigid_residual_flow
        self.region_deformer_scaling = region_deformer_scaling
        self.stop_egomotion_gradient = stop_egomotion_gradient
        self.use_axis_angle = use_axis_angle

        self.trans_params_size = 32  # parameters of the bicubic function

        # Log every constructor argument so runs are reproducible from logs.
        logging.info('data_dir: %s', data_dir)
        logging.info('file_extension: %s', file_extension)
        logging.info('is_training: %s', is_training)
        logging.info('learning_rate: %s', learning_rate)
        logging.info('reconstr_weight: %s', reconstr_weight)
        logging.info('smooth_weight: %s', smooth_weight)
        logging.info('ssim_weight: %s', ssim_weight)
        logging.info('object_depth_weight: %s', object_depth_weight)
        logging.info('object_depth_threshold: %s', object_depth_threshold)
        logging.info('exclude_object_mask: %s', exclude_object_mask)
        logging.info('stop_egomotion_gradient: %s', stop_egomotion_gradient)
        logging.info('beta1: %s', beta1)
        logging.info('batch_size: %s', batch_size)
        logging.info('img_height: %s', img_height)
        logging.info('img_width: %s', img_width)
        logging.info('seq_length: %s', seq_length)
        logging.info('architecture: %s', architecture)
        logging.info('imagenet_norm: %s', imagenet_norm)
        logging.info('weight_reg: %s', weight_reg)
        logging.info('exhaustive_mode: %s', exhaustive_mode)
        logging.info('random_scale_crop: %s', random_scale_crop)
        logging.info('flipping_mode: %s', flipping_mode)
        logging.info('random_color: %s', random_color)
        logging.info('depth_upsampling: %s', depth_upsampling)
        logging.info('depth_normalization: %s', depth_normalization)
        logging.info('compute_minimum_loss: %s', compute_minimum_loss)
        logging.info('use_skip: %s', use_skip)
        logging.info('use_axis_angle: %s', use_axis_angle)
        logging.info('joint_encoder: %s', joint_encoder)
        logging.info('build_sum: %s', build_sum)
        logging.info('shuffle: %s', shuffle)
        logging.info('input_file: %s', input_file)
        logging.info('handle_motion: %s', handle_motion)
        logging.info('equal_weighting: %s', equal_weighting)
        logging.info('same_trans_rot_scaling: %s', same_trans_rot_scaling)
        logging.info('residual_deformer: %s', residual_deformer)
        logging.info('seg_align_type: %s', seg_align_type)
        logging.info('use_rigid_residual_flow: %s', use_rigid_residual_flow)
        logging.info('region_deformer_scaling: %s', region_deformer_scaling)

        if self.is_training:
            self.reader = reader.DataReader(self.data_dir, self.batch_size,
                                            self.img_height, self.img_width,
                                            self.seq_length, NUM_SCALES,
                                            self.file_extension,
                                            self.random_scale_crop,
                                            self.flipping_mode,
                                            self.random_color,
                                            self.imagenet_norm,
                                            self.shuffle,
                                            self.input_file,
                                            self.seg_align_type)
            self.build_train_graph()
        else:
            self.build_depth_test_graph()
            self.build_egomotion_test_graph()

        # At this point, the model is ready. Print some info on model params.
        # (util.count_parameters takes no args here; presumably it counts all
        # trainable variables in the default graph.)
        util.count_parameters()
def train():
    """Train RegiNet to regress registration gradients.

    Each iteration renders a target projection with a fixed (ground-truth)
    transform, runs the registration network, and supervises the gradient of
    the network's similarity loss w.r.t. the pose vector against the Riemannian
    geodesic gradient (rotation and translation parts normalized separately).
    A checkpoint is written every SAVE_MODEL_EVERY_EPOCH epochs.
    """
    criterion_mse = nn.MSELoss()

    param, det_size, _3D_vol, CT_vol, ray_proj_mov, corner_pt, norm_factor = input_param(
        CT_PATH, SEG_PATH, BATCH_SIZE, VOX_SPAC, zFlip)

    initmodel = ProST_init(param).to(device)
    model = RegiNet(param, det_size).to(device)

    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    scheduler = CyclicLR(optimizer,
                         base_lr=1e-6,
                         max_lr=1e-4,
                         step_size_up=100)

    if RESUME_EPOCH >= 0:
        print('Resuming model from epoch', RESUME_EPOCH)
        checkpoint = torch.load(RESUME_MODEL)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        START_EPOCH = RESUME_EPOCH + 1
        step_cnt = RESUME_EPOCH * ITER_NUM
    else:
        START_EPOCH = 0
        step_cnt = 0

    print('module parameters:', count_parameters(model))

    model.train()

    # Per-epoch statistics, reset after each epoch's logging.
    riem_grad_loss_list = []
    riem_grad_rot_loss_list = []
    riem_grad_trans_loss_list = []
    riem_dist_list = []
    riem_dist_mean_list = []
    mse_loss_list = []
    vecgrad_diff_list = []
    total_loss_list = []

    for epoch in range(START_EPOCH, 20000):
        ## Do Iterative Validation
        model.train()
        # Avoid shadowing the builtin `iter`.
        for iteration in tqdm(range(ITER_NUM)):
            step_cnt = step_cnt + 1
            # Get target  projection
            transform_mat3x4_gt, rtvec, rtvec_gt = init_rtvec_train(
                BATCH_SIZE, device)

            # Render the target projection without tracking gradients, then
            # min-max normalize it per batch element.
            with torch.no_grad():
                target = initmodel(CT_vol, ray_proj_mov, transform_mat3x4_gt,
                                   corner_pt)
                min_tar, _ = torch.min(target.reshape(BATCH_SIZE, -1),
                                       dim=-1,
                                       keepdim=True)
                max_tar, _ = torch.max(target.reshape(BATCH_SIZE, -1),
                                       dim=-1,
                                       keepdim=True)
                target = (target.reshape(BATCH_SIZE, -1) -
                          min_tar) / (max_tar - min_tar)
                target = target.reshape(BATCH_SIZE, 1, det_size, det_size)

            # Do Projection and get two encodings
            encode_mov, encode_tar, proj_mov = model(_3D_vol, target, rtvec,
                                                     corner_pt)

            optimizer.zero_grad()
            # Calculate Net l2 Loss, L_N
            l2_loss = criterion_mse(encode_mov, encode_tar)

            # Find geodesic distance (monitoring only, not back-propagated)
            riem_dist = np.sqrt(
                riem.loss(rtvec.detach().cpu(),
                          rtvec_gt.detach().cpu(), METRIC))

            # Gradient of the network loss w.r.t. the pose vector; graph kept
            # so the gradient-matching loss below can be differentiated.
            z = Variable(torch.ones(l2_loss.shape)).to(device)
            rtvec_grad = torch.autograd.grad(l2_loss,
                                             rtvec,
                                             grad_outputs=z,
                                             only_inputs=True,
                                             create_graph=True,
                                             retain_graph=True)[0]
            # Find geodesic gradient
            riem_grad = riem.grad(rtvec.detach().cpu(),
                                  rtvec_gt.detach().cpu(), METRIC)
            riem_grad = torch.tensor(riem_grad,
                                     dtype=torch.float,
                                     requires_grad=False,
                                     device=device)

            ### Translation Loss: match direction of translation gradients
            riem_grad_transnorm = riem_grad[:, 3:] / (
                torch.norm(riem_grad[:, 3:], dim=-1, keepdim=True) + EPS)
            rtvec_grad_transnorm = rtvec_grad[:, 3:] / (
                torch.norm(rtvec_grad[:, 3:], dim=-1, keepdim=True) + EPS)
            riem_grad_trans_loss = torch.mean(
                torch.sum((riem_grad_transnorm - rtvec_grad_transnorm)**2,
                          dim=-1))

            ### Rotation Loss: match direction of rotation gradients
            riem_grad_rotnorm = riem_grad[:, :3] / (
                torch.norm(riem_grad[:, :3], dim=-1, keepdim=True) + EPS)
            rtvec_grad_rotnorm = rtvec_grad[:, :3] / (
                torch.norm(rtvec_grad[:, :3], dim=-1, keepdim=True) + EPS)
            riem_grad_rot_loss = torch.mean(
                torch.sum((riem_grad_rotnorm - rtvec_grad_rotnorm)**2, dim=-1))

            riem_grad_loss = riem_grad_trans_loss + riem_grad_rot_loss

            riem_grad_loss.backward()

            # Clip training gradient magnitude.
            # (clip_grad_norm is deprecated/removed; use the in-place variant.)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
            optimizer.step()
            # Step the LR scheduler after the optimizer, per PyTorch's
            # documented ordering for per-iteration schedulers.
            scheduler.step()

            total_loss = l2_loss

            mse_loss_list.append(torch.mean(l2_loss).detach().item())
            riem_grad_loss_list.append(riem_grad_loss.detach().item())
            riem_grad_rot_loss_list.append(riem_grad_rot_loss.detach().item())
            riem_grad_trans_loss_list.append(
                riem_grad_trans_loss.detach().item())
            riem_dist_list.append(riem_dist)
            riem_dist_mean_list.append(np.mean(riem_dist))
            total_loss_list.append(total_loss.detach().item())
            vecgrad_diff = (rtvec_grad - riem_grad).detach().cpu().numpy()
            vecgrad_diff_list.append(vecgrad_diff)

            torch.cuda.empty_cache()

            cur_lr = float(scheduler.get_last_lr()[0])

            # Bug fix: sys.stdout was previously passed as an (ignored) extra
            # argument to str.format; direct it to the print target instead.
            print('Train epoch: {} Iter: {} tLoss: {:.4f}, gLoss: {:.4f}/{:.2f}, gLoss_rot: {:.4f}/{:.2f}, gLoss_trans: {:.4f}/{:.2f}, LR: {:.4f}'.format(
                        epoch, iteration, np.mean(total_loss_list), np.mean(riem_grad_loss_list), np.std(riem_grad_loss_list),\
                                     np.mean(riem_grad_rot_loss_list), np.std(riem_grad_rot_loss_list),\
                                     np.mean(riem_grad_trans_loss_list), np.std(riem_grad_trans_loss_list),
                        cur_lr), file=sys.stdout)

        if epoch % SAVE_MODEL_EVERY_EPOCH == 0:
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                state,
                SAVE_PATH + '/checkpoint/vali_model' + str(epoch) + '.pt')

        # Reset running statistics for the next epoch.
        riem_grad_loss_list = []
        riem_grad_rot_loss_list = []
        riem_grad_trans_loss_list = []
        riem_dist_list = []
        riem_dist_mean_list = []
        mse_loss_list = []
        vecgrad_diff_list = []
Ejemplo n.º 19
0
def main():
    """Full training / evaluation driver.

    Loads the dataset and model, optionally warm-starts or resumes from a
    checkpoint, then alternates train and eval epochs until ``args.epochs``
    is reached, the allotted wall-clock budget is exceeded, or (with
    ``args.skip_train``) a single eval pass completes.

    Relies on module-level ``args``, ``logger``, ``start_time`` and helper
    functions (``load_model``, ``train``, ``eval``, ``load_checkpoint``,
    ``init_train_state``, ``get_summary_ims``).
    """
    # Fixed seeds for reproducible data order / initialization.
    np.random.seed(0)
    torch.manual_seed(0)

    logger.info('Loading data...')
    train_loader, val_loader, classes = custom_dataset.load_data(args)

    # override autodetect if n_classes is given
    if args.n_classes > 0:
        classes = np.arange(args.n_classes)

    model = load_model(classes)

    logger.info('Loaded model; params={}'.format(util.count_parameters(model)))
    if not args.cpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    model.to(device)
    cudnn.benchmark = True  # autotune conv kernels; assumes fixed input sizes
    logger.info('Running on ' + str(device))

    summary_writer = Logger(args.logdir)

    # Loss and Optimizer
    n_epochs = args.epochs
    # NOTE(review): label smoothing is realized by switching to BCE-with-logits;
    # the smoothing value itself is presumably applied inside train() -- confirm.
    if args.label_smoothing > 0:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    train_state = init_train_state()
    # freeze layers (must happen before the optimizer captures parameters)
    for l in args.freeze_layers:
        for p in getattr(model, l).parameters():
            p.requires_grad = False
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=train_state['lr'],
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'nesterov':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=train_state['lr'],
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    # this is used to warm-start
    if args.warm_start_from:
        logger.info('Warm-starting from {}'.format(args.warm_start_from))
        assert os.path.isfile(args.warm_start_from)
        train_state = load_checkpoint(args.warm_start_from, model, optimizer)
        logger.info('Params loaded.')
        # warm start loads weights only: reset the training schedule/state
        # so we do not inherit the donor run's epoch/step counters
        train_state = init_train_state()

    ckptfile = str(Path(args.logdir) / args.latest_fname)
    if os.path.isfile(ckptfile):
        logger.info('Loading checkpoint: {}'.format(ckptfile))
        train_state = load_checkpoint(ckptfile, model, optimizer)
        logger.info('Params loaded.')
    else:
        logger.info('Checkpoint {} not found; ignoring.'.format(ckptfile))

    # Training / Eval loop
    epoch_time = []                 # store time per epoch
    # we save epoch+1 to checkpoints; but for eval we should repeat prev. epoch
    if args.skip_train:
        train_state['start_epoch'] -= 1
    for epoch in range(train_state['start_epoch'], n_epochs):
        logger.info('Epoch: [%d/%d]' % (epoch + 1, n_epochs))
        start = time.time()

        if not args.skip_train:
            model.train()
            train(train_loader, device, model, criterion, optimizer, summary_writer, train_state,
                  n_classes=len(classes))
            logger.info('Time taken: %.2f sec...' % (time.time() - start))
            if epoch == 0:
                train_state['steps_epoch'] = train_state['step']
        # always eval on last epoch
        if not args.skip_eval or epoch == n_epochs - 1:
            logger.info('\n Starting evaluation...')
            model.eval()
            # SHREC-style retrieval eval only on the final epoch, and only
            # when a retrieval directory was supplied.
            eval_shrec = True if epoch == n_epochs - 1 and args.retrieval_dir else False
            metrics, inputs = eval(
                val_loader, device, model, criterion, eval_shrec)

            logger.info('\tcombined: %.2f, Acc: %.2f, mAP: %.2f, Loss: %.4f' %
                        (metrics['combined'],
                         metrics['acc_inst'],
                         metrics.get('mAP_inst', 0.),
                         metrics['loss']))

            # Log epoch to tensorboard
            # See log using: tensorboard --logdir='logs' --port=6006
            ims = get_summary_ims(inputs)
            if not args.nolog:
                util.logEpoch(summary_writer, model, epoch + 1, metrics, ims)
        else:
            metrics = None

        # Decaying Learning Rate
        if args.lr_decay_mode == 'step':
            if (epoch + 1) % args.lr_decay_freq == 0:
                train_state['lr'] *= args.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = train_state['lr']

        # Save model
        if not args.skip_train:
            logger.info('\tSaving latest model')
            util.save_checkpoint({
                'epoch': epoch + 1,
                'step': train_state['step'],
                'steps_epoch': train_state['steps_epoch'],
                'state_dict': model.state_dict(),
                'metrics': metrics,
                'optimizer': optimizer.state_dict(),
                'lr': train_state['lr'],
            },
                str(Path(args.logdir) / args.latest_fname))

        total_epoch_time = time.time() - start
        epoch_time.append(total_epoch_time)
        logger.info('Total time for this epoch: {} s'.format(total_epoch_time))

        # if last epoch, show eval results
        if epoch == n_epochs - 1:
            logger.info(
                '|model|combined|acc inst|acc cls|mAP inst|mAP cls|loss|')
            logger.info('|{}|{:.2f}|{:.2f}|{:.2f}|{:.2f}|{:.2f}|{:.4f}|'
                        .format(os.path.basename(args.logdir),
                                metrics['combined'],
                                metrics['acc_inst'],
                                metrics['acc_cls'],
                                metrics.get('mAP_inst', 0.),
                                metrics.get('mAP_cls', 0.),
                                metrics['loss']))

        if args.skip_train:
            # if evaluating, run it once
            break

        # NOTE(review): this compares time.perf_counter() against start_time
        # (set elsewhere); both sides must use the same clock -- confirm that
        # start_time is also taken from time.perf_counter().
        if time.perf_counter() + np.max(epoch_time) > start_time + args.exit_after:
            logger.info('Next epoch will likely exceed alotted time; exiting...')
            break
Ejemplo n.º 20
0
def main():
    """Two-phase training driver: encoder/classifier, then decoder.

    Phase 1 trains and evaluates the multi-view encoder exactly like the
    plain training loop, additionally collecting feature descriptors for
    ``train_chairs`` on the final training epoch and pickling them to
    ``descriptors.dat``.  Phase 2 trains a ``models.Decoder`` to reconstruct
    images from (image, descriptor) batches streamed out of ``D1.dat``.

    Relies on module-level ``args``, ``logger``, ``start_time`` and helpers
    (``load_model``, ``train``, ``eval``, ``load_checkpoint``,
    ``init_train_state``, ``get_summary_ims``).
    """
    train_chairs = [
        'chair_0001', 'chair_0005', 'chair_0101', 'chair_0084', 'chair_0497',
        'chair_0724', 'chair_0878'
    ]
    test_chairs = ['chair_0957']
    features = []
    # Fixed seeds for reproducible data order / initialization.
    np.random.seed(0)
    torch.manual_seed(0)

    logger.info('Loading data...')
    train_loader, val_loader, classes = custom_dataset.load_data(args)

    # Override autodetected class count when n_classes is given explicitly.
    if args.n_classes > 0:
        classes = np.arange(args.n_classes)

    model = load_model(classes)

    logger.info('Loaded model; params={}'.format(util.count_parameters(model)))
    if not args.cpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    model.to(device)
    cudnn.benchmark = True  # autotune conv kernels; assumes fixed input sizes
    logger.info('Running on ' + str(device))

    summary_writer = Logger(args.logdir)

    # Loss and optimizer.
    n_epochs = args.epochs
    # NOTE(review): label smoothing is realized by switching to
    # BCE-with-logits; the smoothing value itself is presumably applied
    # inside train() -- confirm.
    if args.label_smoothing > 0:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    train_state = init_train_state()
    # Freeze the requested layers before the optimizer captures parameters.
    for l in args.freeze_layers:
        for p in getattr(model, l).parameters():
            p.requires_grad = False
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=train_state['lr'],
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'nesterov':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=train_state['lr'],
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    # Warm start loads weights only: reset the training schedule/state so we
    # do not inherit the donor run's epoch/step counters.
    if args.warm_start_from:
        logger.info('Warm-starting from {}'.format(args.warm_start_from))
        assert os.path.isfile(args.warm_start_from)
        train_state = load_checkpoint(args.warm_start_from, model, optimizer)
        logger.info('Params loaded.')
        train_state = init_train_state()

    ckptfile = str(Path(args.logdir) / args.latest_fname)
    if os.path.isfile(ckptfile):
        logger.info('Loading checkpoint: {}'.format(ckptfile))
        train_state = load_checkpoint(ckptfile, model, optimizer)
        logger.info('Params loaded.')
    else:
        logger.info('Checkpoint {} not found; ignoring.'.format(ckptfile))

    # Training / eval loop.
    epoch_time = []  # wall-clock time per epoch, for the early-exit estimate
    # We save epoch+1 to checkpoints; for eval-only runs repeat prev. epoch.
    if args.skip_train:
        train_state['start_epoch'] -= 1
    # BUGFIX: the loop previously started at 0 unconditionally, ignoring the
    # resume epoch restored from the checkpoint (and making the skip_train
    # adjustment above dead code).
    for epoch in range(train_state['start_epoch'], n_epochs):

        logger.info('Epoch: [%d/%d]' % (epoch + 1, n_epochs))
        start = time.time()

        if not args.skip_train:
            model.train()

            if epoch == n_epochs - 1:
                # Final epoch: also collect per-chair feature descriptors.
                features = train(train_loader,
                                 device,
                                 model,
                                 criterion,
                                 optimizer,
                                 summary_writer,
                                 train_state,
                                 1,
                                 train_chairs,
                                 n_classes=len(classes))

                # BUGFIX: previously dumped the undefined name `train_desc`,
                # which raised NameError on the last epoch.
                PIK = "descriptors.dat"
                with open(PIK, "wb") as f:
                    pickle.dump(features, f)

            else:

                train(train_loader,
                      device,
                      model,
                      criterion,
                      optimizer,
                      summary_writer,
                      train_state,
                      0,
                      train_chairs,
                      n_classes=len(classes))

            logger.info('Time taken: %.2f sec...' % (time.time() - start))
            if epoch == 0:
                train_state['steps_epoch'] = train_state['step']
        # Always eval on the last epoch.
        # BUGFIX: the forced-eval condition was `epoch == n_epochs + 1`,
        # which is unreachable since epoch never exceeds n_epochs - 1.
        if not args.skip_eval or epoch == n_epochs - 1:
            logger.info('\n Starting evaluation...')
            model.eval()
            # SHREC-style retrieval eval only on the final epoch, and only
            # when a retrieval directory was supplied.
            eval_shrec = (epoch == n_epochs - 1) and bool(args.retrieval_dir)
            metrics, inputs = eval(val_loader, device, model, criterion,
                                   eval_shrec, 0, test_chairs, features)

            logger.info('\tcombined: %.2f, Acc: %.2f, mAP: %.2f, Loss: %.4f' %
                        (metrics['combined'], metrics['acc_inst'],
                         metrics.get('mAP_inst', 0.), metrics['loss']))

            # Log epoch to tensorboard.
            # See log using: tensorboard --logdir='logs' --port=6006
            ims = get_summary_ims(inputs)
            if not args.nolog:
                util.logEpoch(summary_writer, model, epoch + 1, metrics, ims)
        else:
            metrics = None

        # Decaying learning rate.
        if args.lr_decay_mode == 'step':
            if (epoch + 1) % args.lr_decay_freq == 0:
                train_state['lr'] *= args.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = train_state['lr']

        # Save the latest model.
        if not args.skip_train:
            logger.info('\tSaving latest model')
            util.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'step': train_state['step'],
                    'steps_epoch': train_state['steps_epoch'],
                    'state_dict': model.state_dict(),
                    'metrics': metrics,
                    'optimizer': optimizer.state_dict(),
                    'lr': train_state['lr'],
                }, str(Path(args.logdir) / args.latest_fname))

        total_epoch_time = time.time() - start
        epoch_time.append(total_epoch_time)
        logger.info('Total time for this epoch: {} s'.format(total_epoch_time))

        if args.skip_train:
            # Eval-only: run a single pass.
            break

        # NOTE(review): this compares time.perf_counter() against start_time
        # (set elsewhere); both sides must use the same clock -- confirm.
        if time.perf_counter() + np.max(
                epoch_time) > start_time + args.exit_after:
            logger.info(
                'Next epoch will likely exceed alotted time; exiting...')
            break

    print("Encoder training done")
    print("Now training the Decoder")

    ###############################Decoder ###########################################

    decoder = models.Decoder()
    print(decoder)
    decoder.to(device)

    train_state = init_train_state()

    crit = nn.MSELoss()

    optim = torch.optim.SGD(decoder.parameters(),
                            lr=train_state['lr'],
                            momentum=args.momentum,
                            weight_decay=args.weight_decay,
                            nesterov=True)

    # TODO: make this path configurable instead of hard-coding a home dir.
    path = "/home/smjadhav/Research/emvn/decoder_model/latest.pth.tar"
    if os.path.isfile(path):
        logger.info('Loading decoder checkpoint: {}'.format(path))
        # BUGFIX: the decoder checkpoint was restored into the *encoder's*
        # `optimizer`; it must restore the decoder's `optim`.
        train_state = load_checkpoint(path, decoder, optim)
        logger.info('Params loaded.')
    else:
        print("Decoder model not found")

    train_size = len(train_loader)
    metrics = {}
    outputs = None  # last decoder batch; pickled below for inspection

    for epoch in range(0, 50):
        print("Epoch ", epoch + 1)
        decoder.train()

        PIK = "D1.dat"

        with open(PIK, "rb") as f:
            # Stream (image, descriptor) batches until the pickle file is
            # exhausted.
            # BUGFIX: this used a bare `except:` whose body was the bare name
            # `exit` (never called), silently swallowing *all* errors; only
            # the end-of-stream EOFError should terminate the loop.
            try:
                i = 0

                while True:

                    data = pickle.load(f)

                    inputs = torch.from_numpy(data[1]).to(device)
                    target_img = torch.from_numpy(data[0]).to(device)
                    outputs = decoder(inputs)

                    optim.zero_grad()
                    loss = crit(outputs, target_img)
                    loss.backward()
                    optim.step()

                    if args.lr_decay_mode == 'cos':
                        # Estimate steps_epoch from the first epoch (entries
                        # may have been dropped).
                        steps_epoch = (train_state['steps_epoch']
                                       if train_state['steps_epoch'] > 0 else
                                       len(train_loader))
                        # TODO: there will be a jump here if many entries are
                        # dropped and we only figure out # of steps after the
                        # first epoch.

                        if train_state['step'] < steps_epoch:
                            train_state['lr'] = args.lr * train_state[
                                'step'] / steps_epoch
                        else:
                            nsteps = steps_epoch * args.epochs
                            train_state['lr'] = (0.5 * args.lr * (1 + np.cos(
                                train_state['step'] * np.pi / nsteps)))
                        for param_group in optim.param_groups:
                            param_group['lr'] = train_state['lr']

                    if (i + 1) % args.print_freq == 0:
                        print("\tIter [%d/%d] Loss: %.4f" %
                              (i + 1, train_size, loss.item()))

                    if args.max_steps > 0 and i > args.max_steps:
                        break
                    i = i + 1

            except EOFError:
                pass  # end of the pickle stream

        if (epoch + 1) % 5 == 0:
            print("Saving Decoder model")
            # BUGFIX: previously saved the encoder's `optimizer` state into
            # the decoder checkpoint; save the decoder's `optim` instead.
            util.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'step': train_state['step'],
                    'steps_epoch': train_state['steps_epoch'],
                    'state_dict': decoder.state_dict(),
                    'metrics': metrics,
                    'optimizer': optim.state_dict(),
                    'lr': train_state['lr'],
                }, path)
            PIK = "images.dat"
            with open(PIK, "wb") as f:
                pickle.dump(outputs, f)
Ejemplo n.º 21
0
  def __init__(self,
               data_dir=None,
               file_extension='png',
               is_training=True,
               learning_rate=0.0002,
               beta1=0.9,
               reconstr_weight=0.85,
               smooth_weight=0.05,
               ssim_weight=0.15,
               icp_weight=0.0,
               batch_size=4,
               img_height=128,
               img_width=416,
               seq_length=3,
               architecture=nets.RESNET,
               imagenet_norm=True,
               weight_reg=0.05,
               exhaustive_mode=False,
               random_scale_crop=False,
               flipping_mode=reader.FLIP_RANDOM,
               random_color=True,
               depth_upsampling=True,
               depth_normalization=True,
               compute_minimum_loss=True,
               use_skip=True,
               joint_encoder=True,
               build_sum=True,
               shuffle=True,
               input_file='train',
               handle_motion=False,
               equal_weighting=False,
               size_constraint_weight=0.0,
               train_global_scale_var=True):
    """Stores hyperparameters and builds the TF graph for train or test mode.

    In training mode, constructs a `reader.DataReader` input pipeline and the
    training graph; otherwise builds the depth and egomotion inference graphs
    (plus the object-motion graph when `handle_motion` is set).

    Args:
      data_dir: Directory containing the training data.
      file_extension: Image file extension (e.g. 'png').
      is_training: Whether to build the training graph (vs. test graphs).
      learning_rate: Adam learning rate.
      beta1: Adam beta1 parameter.
      reconstr_weight: Weight of the image reconstruction loss.
      smooth_weight: Weight of the depth smoothness loss.
      ssim_weight: Weight of the SSIM loss.
      icp_weight: Weight of the ICP loss.
      batch_size: Training batch size.
      img_height: Input image height in pixels.
      img_width: Input image width in pixels.
      seq_length: Number of frames per training sequence.
      architecture: Depth-network architecture (see `nets`).
      imagenet_norm: Whether to normalize inputs with ImageNet statistics.
      weight_reg: Weight regularization strength.
      exhaustive_mode: Whether to use exhaustive pairing mode.
      random_scale_crop: Whether to randomly scale and crop inputs.
      flipping_mode: Input flipping mode (see `reader`).
      random_color: Whether to apply random color augmentation.
      depth_upsampling: Whether to upsample multi-scale depth predictions.
      depth_normalization: Whether to normalize depth predictions.
      compute_minimum_loss: Whether to use the per-pixel minimum loss.
      use_skip: Whether the depth network uses skip connections.
      joint_encoder: Whether depth and egomotion share an encoder.
      build_sum: Whether to build TF summaries.
      shuffle: Whether to shuffle the input data.
      input_file: Base name of the input file list (e.g. 'train').
      handle_motion: Whether to model individual object motion.
      equal_weighting: Whether to weight loss terms equally.
      size_constraint_weight: Weight of the object-size constraint loss; a
        positive value enables the learnable global scale variable.
      train_global_scale_var: Whether the global scale variable is trainable.
    """
    self.data_dir = data_dir
    self.file_extension = file_extension
    self.is_training = is_training
    self.learning_rate = learning_rate
    self.reconstr_weight = reconstr_weight
    self.smooth_weight = smooth_weight
    self.ssim_weight = ssim_weight
    self.icp_weight = icp_weight
    self.beta1 = beta1
    self.batch_size = batch_size
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.architecture = architecture
    self.imagenet_norm = imagenet_norm
    self.weight_reg = weight_reg
    self.exhaustive_mode = exhaustive_mode
    self.random_scale_crop = random_scale_crop
    self.flipping_mode = flipping_mode
    self.random_color = random_color
    self.depth_upsampling = depth_upsampling
    self.depth_normalization = depth_normalization
    self.compute_minimum_loss = compute_minimum_loss
    self.use_skip = use_skip
    self.joint_encoder = joint_encoder
    self.build_sum = build_sum
    self.shuffle = shuffle
    self.input_file = input_file
    self.handle_motion = handle_motion
    self.equal_weighting = equal_weighting
    self.size_constraint_weight = size_constraint_weight
    self.train_global_scale_var = train_global_scale_var

    logging.info('data_dir: %s', data_dir)
    logging.info('file_extension: %s', file_extension)
    logging.info('is_training: %s', is_training)
    logging.info('learning_rate: %s', learning_rate)
    logging.info('reconstr_weight: %s', reconstr_weight)
    logging.info('smooth_weight: %s', smooth_weight)
    logging.info('ssim_weight: %s', ssim_weight)
    logging.info('icp_weight: %s', icp_weight)
    logging.info('size_constraint_weight: %s', size_constraint_weight)
    logging.info('beta1: %s', beta1)
    logging.info('batch_size: %s', batch_size)
    logging.info('img_height: %s', img_height)
    logging.info('img_width: %s', img_width)
    logging.info('seq_length: %s', seq_length)
    logging.info('architecture: %s', architecture)
    logging.info('imagenet_norm: %s', imagenet_norm)
    logging.info('weight_reg: %s', weight_reg)
    logging.info('exhaustive_mode: %s', exhaustive_mode)
    logging.info('random_scale_crop: %s', random_scale_crop)
    logging.info('flipping_mode: %s', flipping_mode)
    logging.info('random_color: %s', random_color)
    logging.info('depth_upsampling: %s', depth_upsampling)
    logging.info('depth_normalization: %s', depth_normalization)
    logging.info('compute_minimum_loss: %s', compute_minimum_loss)
    logging.info('use_skip: %s', use_skip)
    logging.info('joint_encoder: %s', joint_encoder)
    logging.info('build_sum: %s', build_sum)
    logging.info('shuffle: %s', shuffle)
    logging.info('input_file: %s', input_file)
    logging.info('handle_motion: %s', handle_motion)
    logging.info('equal_weighting: %s', equal_weighting)
    logging.info('train_global_scale_var: %s', train_global_scale_var)

    # Learnable, non-negative global scale; trainable only when training with
    # train_global_scale_var enabled.
    if self.size_constraint_weight > 0 or not is_training:
      self.global_scale_var = tf.Variable(
          0.1, name='global_scale_var',
          trainable=self.is_training and train_global_scale_var,
          dtype=tf.float32,
          constraint=lambda x: tf.clip_by_value(x, 0, np.infty))

    if self.is_training:
      self.reader = reader.DataReader(self.data_dir, self.batch_size,
                                      self.img_height, self.img_width,
                                      self.seq_length, NUM_SCALES,
                                      self.file_extension,
                                      self.random_scale_crop,
                                      self.flipping_mode,
                                      self.random_color,
                                      self.imagenet_norm,
                                      self.shuffle,
                                      self.input_file)
      self.build_train_graph()
    else:
      self.build_depth_test_graph()
      self.build_egomotion_test_graph()
      # BUGFIX: the object-motion test graph was previously built twice
      # (once unconditionally and once under handle_motion); build it only
      # when object motion handling is enabled.
      if self.handle_motion:
        self.build_objectmotion_test_graph()

    # At this point, the model is ready. Print some info on model params.
    util.count_parameters()
Ejemplo n.º 22
0
                           help="Hidden dim")
    optparser.add_argument(
        "-s",
        "--cs",
        dest="cs",
        action="store_true",
        default=False,
        help="Reconnect cell state",
    )
    opts = optparser.parse_args()

    model = baseline_models.LSTMModel(vocab_size=10,
                                      nhid=opts.nhid,
                                      d_in=28**2,
                                      d_out=28**2)
    print("LSTM parameters: ", util.count_parameters(model))

    dataset = rsm_samplers.MNISTBufferedDataset(
        expanduser("~/nta/datasets"),
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    )

    sampler = rsm_samplers.MNISTSequenceSampler(
        dataset,
        sequences=PAGI9,
        batch_size=BSZ,
        noise_buffer=opts.noise,