Exemplo n.º 1
0
def test(rank,world_size,args):
    """Evaluate a trained model with MSE on this rank's partition of the
    validation set.

    Every process seeds torch identically, builds the same model, loads the
    checkpoint, evaluates its own partition, and prints its average MSE.

    Args:
        rank: id of this process within the process group.
        world_size: number of processes sharing the dataset.
        args: parsed command-line namespace; must include checkpoint_path.

    Raises:
        ValueError: if args.checkpoint_path is None — the original code only
            printed a warning and then silently evaluated a randomly
            initialized model.
    """
    # Info
    hostname = socket.gethostname().split('.')[0] # for printing
    print("Started process on node: ", hostname)

    # Model
    model_out = 'pred' # output is always pred for mse loss
    device = 'cpu' # cpu only
    torch.manual_seed(args.seed) # all processes start with the same model
    if args.model_type == 'PredNet':
        model = PredNet(args.in_channels,args.stack_sizes,args.R_stack_sizes,
                        args.A_kernel_sizes,args.Ahat_kernel_sizes,
                        args.R_kernel_sizes,args.use_satlu,args.pixel_max,
                        args.Ahat_act,args.satlu_act,args.error_act,
                        args.LSTM_act,args.LSTM_c_act,args.bias,
                        args.use_1x1_out,args.FC,model_out,device)
    elif args.model_type == 'ConvLSTM':
        model = ConvLSTM(args.in_channels,args.hidden_channels,args.kernel_size,
                         args.LSTM_act,args.LSTM_c_act,args.out_act,
                         args.bias,args.FC,device)
    # Load from checkpoint: fail fast rather than test an untrained model
    if args.checkpoint_path is None:
        raise ValueError("Must include checkpoint_path argument to test")
    model.load_state_dict(torch.load(args.checkpoint_path))
    model.eval()

    # Data
    if args.dataset == 'KITTI':
        dataset = KITTI(args.val_data_path,args.val_sources_path,args.seq_len)
    elif args.dataset == 'CCN':
        dataset = CCN(args.val_data_path,args.seq_len)
    partitioner = DataPartitioner(dataset, world_size)
    partition = partitioner.get_partition(rank)
    val_loader = DataLoader(partition, args.batch_size,
                              shuffle=True,num_workers=1,pin_memory=False)
    if rank == 0:
        print("Val dataset has %d samples total" % len(dataset))
    print("%s: Partition of val dataset has %d samples" % (hostname,
                                                           len(partition)))

    # Loss function: always use mse for testing
    mse_loss = nn.MSELoss()

    # Test model on partition
    with torch.no_grad():
        losses = []
        for X in val_loader:
            # Forward
            X = X.to(device)
            output = model(X)
            # Compute loss against frames 1..T (t=0 is never predicted)
            X_no_t0 = X[:,1:,:,:,:]
            loss = mse_loss(output,X_no_t0)
            # Record loss
            loss_datapoint = loss.data.item()
            losses.append(loss_datapoint)
    print("Average MSE: ", np.mean(losses))
Exemplo n.º 2
0
    def init_graph_conv_lstm(self):
        """Create the link ConvLSTM and append its per-edge readout head."""
        in_size = self.args['edge_feature_size']
        n_hidden = self.args['link_hidden_size']
        n_layers = self.args['link_hidden_layers']

        self.ConvLSTM = ConvLSTM.ConvLSTM(in_size, n_hidden, n_layers)
        # Readout: a 1x1 convolution squeezes the hidden channels down to a
        # single score map, then a sigmoid maps scores into [0, 1].
        self.learn_modules.append(torch.nn.Conv2d(n_hidden, 1, 1))
        self.learn_modules.append(torch.nn.Sigmoid())
    def __init__(self):
        """Build the layers of the single-scale recurrent network.

        Construction order is significant: it fixes both the parameter
        registration order and the torch RNG stream used for weight init.
        """
        super(Single_Scale_Recurrent_Network, self).__init__()
        # Two input convolutions: in1 takes a 3-channel input, in2 a
        # 6-channel one (presumably two images stacked on channels — TODO
        # confirm against callers).
        self.in1 = nn.Conv2d(in_channels=3, out_channels=32,
                             kernel_size=5, stride=1, padding=2)
        self.in2 = nn.Conv2d(in_channels=6, out_channels=32,
                             kernel_size=5, stride=1, padding=2)

        # Recurrent bottleneck over 64x64 feature maps with 128 channels.
        self.convlstm = ConvLSTM(input_size=(64, 64),
                                 input_dim=128,
                                 hidden_dim=128,
                                 kernel_size=(5, 5),
                                 num_layers=1,
                                 batch_first=True,
                                 bias=True,
                                 return_all_layers=False)
        # Encoder (EBlocks) and decoder (DBlocks) stages around the bottleneck.
        self.inblock = InBlock()
        self.eblock1 = EBlock(32)
        self.eblock2 = EBlock(64)
        self.dblock1 = DBlock(128)
        self.dblock2 = DBlock(64)
        self.outblock = OutBlock()
Exemplo n.º 4
0
    def _build_net(self):
        """Build the network graph: 4 x (ConvLSTM2D + layer norm), then a
        Conv3D head.

        The original body repeated the ConvLSTM2D + layer_norm stanza four
        times verbatim; it is collapsed into a loop with identical wiring
        (output of each layer norm feeds the next ConvLSTM).

        Returns:
            Tensor of shape (batch x D x H x W x C) — the Conv3D output
            transposed back to channels-last.
        """
        rnn_input = self.input
        for _ in range(4):
            # ConvLSTM2D
            rnn_out, last_hidden = ConvLSTM.convlstm2d_rnn(
                rnn_input=rnn_input,
                init_hidden=None,
                init_cell=None,
                padding=1,
                hidden_h=self.h,
                hidden_w=self.w,
                filters=self.filters,
                filter_size=self.filter_size,
                sequence_length=self.input_seqlen)

            # Batch Norm (actually layer norm over the channel axis)
            bn = layers.layer_norm(rnn_out, begin_norm_axis=4)
            rnn_input = bn

        # Transpose : (batch x C x D x H x W)
        tr = layers.transpose(bn, [0, 4, 1, 2, 3])

        # Conv3D
        conv3d = layers.conv3d(input=tr,
                               num_filters=2,
                               filter_size=3,
                               padding=1)
        # conv3d : (batch x C x D x H x W)

        conv3d = layers.transpose(conv3d, [0, 2, 3, 4, 1])
        # conv3d: (batch x D x H x W x C)

        return conv3d
Exemplo n.º 5
0
    return data_x.astype(np.float32), data_y.reshape(len(data_y)).astype(np.uint8)


#(557963,113)  (118750,113)
# Load the gesture dataset and cut it into fixed-length sliding windows.
X_train, y_train, X_test, y_test = load_dataset('oppChallenge_gestures.data')
assert NB_SENSOR_CHANNELS == X_train.shape[1]
X_train, y_train = opp_sliding_window(X_train, y_train, SLIDING_WINDOW_LENGTH, SLIDING_WINDOW_STEP)
X_test, y_test = opp_sliding_window(X_test, y_test, SLIDING_WINDOW_LENGTH, SLIDING_WINDOW_STEP)  #(9894, 24, 113) 

# DataLoaders: shuffle=False for both, so sample order is deterministic.
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(MyDataset(X_train, y_train), 
                        batch_size=BATCH_SIZE, shuffle=False, **kwargs)
test_loader = torch.utils.data.DataLoader(MyDataset(X_test, y_test), 
                        batch_size=BATCH_SIZE, shuffle=False, **kwargs)

# NOTE(review): the freshly constructed model is immediately discarded by the
# torch.load on the next line — presumably 'model.pkl' holds a pretrained
# ConvLSTM; confirm intent (load_state_dict on the new model may be meant).
model = ConvLSTM.ConvLSTM(num_classes=18)
model = torch.load('model.pkl')

model = model.to(DEVICE)
optimizer = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=MOMEMTUN,
        weight_decay=L2_WEIGHT
        )
print(model)
#train
for e in tqdm(range(1, EPOCH + 1)):
    model = train(model=model, optimizer=optimizer,
            #epoch=e, data_src=[X_train,y_train], data_tar=[X_test,y_test])
Exemplo n.º 6
0
def train(rank, world_size, args):
    """Distributed data-parallel training loop (CPU only).

    Every rank seeds torch identically so all replicas start from the same
    weights, trains on its own partition of the dataset, and calls
    average_gradients() after each backward pass to keep replicas in sync.
    Per-rank loss stats are written to a JSON file every epoch; rank 0 also
    saves model weights.

    Args:
        rank: id of this process within the process group.
        world_size: number of participating processes.
        args: parsed command-line namespace (model, data, optim settings).
    """

    # Info
    hostname = socket.gethostname().split('.')[0] # for printing
    print("Started process on node: ", hostname)

    # Model
    model_out = 'error' if args.loss == 'E' else 'pred'
    device = 'cpu' # cpu only
    torch.manual_seed(args.seed) # all processes start with the same model
    if args.model_type == 'PredNet':
        model = PredNet(args.in_channels,args.stack_sizes,args.R_stack_sizes,
                        args.A_kernel_sizes,args.Ahat_kernel_sizes,
                        args.R_kernel_sizes,args.use_satlu,args.pixel_max,
                        args.Ahat_act,args.satlu_act,args.error_act,
                        args.LSTM_act,args.LSTM_c_act,args.bias,
                        args.use_1x1_out,args.FC,args.send_acts,args.no_ER,
                        args.RAhat,args.local_grad,model_out,device)
    elif args.model_type == 'MultiConvLSTM':
        model = MultiConvLSTM(args.in_channels,args.R_stack_sizes,
                              args.R_kernel_sizes,args.use_satlu,args.pixel_max,
                              args.Ahat_act,args.satlu_act,args.error_act,
                              args.LSTM_act,args.LSTM_c_act,args.bias,
                              args.use_1x1_out,args.FC,args.local_grad,
                              model_out,device)
    elif args.model_type == 'ConvLSTM':
        model = ConvLSTM(args.in_channels,args.hidden_channels,args.kernel_size,
                         args.LSTM_act,args.LSTM_c_act,args.out_act,
                         args.bias,args.FC,device)

    if args.load_weights_from is not None:
        model.load_state_dict(torch.load(args.load_weights_from))
    model.train()

    # Data
    if args.dataset == 'KITTI':
        dataset = KITTI(args.train_data_path,args.train_sources_path,
                           args.seq_len)
    elif args.dataset == 'CCN':
        dataset = CCN(args.train_data_path,args.seq_len)
    partitioner = DataPartitioner(dataset, world_size)
    partition = partitioner.get_partition(rank)
    train_loader = DataLoader(partition, args.batch_size,
                              shuffle=True,num_workers=1,pin_memory=False)
    if rank == 0:
        print("Train dataset has %d samples total" % len(dataset))
    print("%s: Partition of train dataset has %d samples" % (hostname,
                                                             len(partition)))

    # Loss function
    loss_fn = get_loss_fn(args.loss,args.layer_lambdas)

    # Optimizer
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # Guard against step_size == 0 (num_iters < lr_steps + 1), which would
    # make StepLR divide by zero.
    lrs_step_size = max(1, args.num_iters // (args.lr_steps+1))
    scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=lrs_step_size,
                                          gamma=0.1)

    # Stats: one results file per rank
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)
    results_fn = 'r%d_' % rank + args.out_data_file
    results_path = os.path.join(args.results_dir,results_fn)
    loss_data = [] # records loss every args.record_loss_every iters

    # Training loop: `iteration` renamed from `iter`, which shadowed the
    # builtin. It counts optimizer steps across epochs.
    iteration = 0
    epoch_count = 0
    ave_iter_time = 0.0
    ave_reduce_time = 0.0
    while iteration < args.num_iters:
        epoch_count += 1
        for X in train_loader:
            iteration += 1
            optimizer.zero_grad()
            # Forward
            iter_tick = time.time()
            output = model(X)
            # Compute loss
            if args.loss == 'E':
                loss = loss_fn(output)
            else:
                X_no_t0 = X[:,1:,:,:,:] # model predicts frames 1..T, not t=0
                loss = loss_fn(output,X_no_t0)
            # Backward pass
            loss.backward()
            iter_tock = time.time()
            # All reduce: average gradients across all replicas BEFORE the
            # optimizer step, so every rank applies the same update.
            reduce_tick = time.time()
            average_gradients(model) # average gradients across all models
            reduce_tock = time.time()
            # Optimizer, scheduler
            optimizer.step()
            scheduler.step()
            # Time stats: running means over all iterations so far
            iter_time = iter_tock - iter_tick
            reduce_time = reduce_tock - reduce_tick
            ave_iter_time = (ave_iter_time*(iteration-1) + iter_time)/iteration
            ave_reduce_time = (ave_reduce_time*(iteration-1)
                               + reduce_time)/iteration
            # Record loss
            if iteration % args.record_loss_every == 0:
                loss_datapoint = loss.data.item()
                # NOTE(review): scheduler.get_lr() is deprecated in newer
                # torch in favor of get_last_lr() — confirm torch version.
                print(hostname,
                      'Rank:',rank,
                      'Epoch:', epoch_count,
                      'Iter:', iteration,
                      'Ave iter time:',ave_iter_time,
                      'Ave reduce time:',ave_reduce_time,
                      'Loss:', loss_datapoint,
                      'lr:', scheduler.get_lr())
                loss_data.append(loss_datapoint)
            if iteration >= args.num_iters:
                break
        # Write stats file once per epoch
        stats = {'loss_data':loss_data}
        with open(results_path, 'w') as f:
            json.dump(stats, f)
        if rank == 0 and args.checkpoint_path is not None:
            print("Saving weights to %s" % args.checkpoint_path)
            torch.save(model.state_dict(),
                       args.checkpoint_path)
Exemplo n.º 7
0
def main(args):
    """Single-process training driver with periodic checkpoint evaluation.

    Builds train/val/test loaders, trains the selected model, records loss
    (and optionally per-layer E stats and prediction/target correlation),
    evaluates all three splits every `checkpoint_every` epochs, dumps stats
    to JSON, and saves the weights with the best validation loss.

    Relies on module-level helpers (get_loss_fn, checkpoint, correlation)
    and model classes defined elsewhere in this file.
    """
    # CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Data
    if args.dataset == 'KITTI':
        train_data = KITTI(args.train_data_path,args.train_sources_path,
                           args.seq_len)
        val_data = KITTI(args.val_data_path,args.val_sources_path,
                         args.seq_len)
        test_data = KITTI(args.test_data_path,args.test_sources_path,
                          args.seq_len)
    elif args.dataset == 'CCN':
        downsample_size = (args.downsample_size,args.downsample_size)
        train_data = CCN(args.train_data_path,args.seq_len,
                         downsample_size=downsample_size,
                         last_only=args.last_only)
        val_data = CCN(args.val_data_path,args.seq_len,
                       downsample_size=downsample_size,
                       last_only=args.last_only)
        test_data = CCN(args.test_data_path,args.seq_len,
                        downsample_size=downsample_size,
                        last_only=args.last_only)
    train_loader = DataLoader(train_data,args.batch_size,shuffle=True)
    val_loader = DataLoader(val_data,args.batch_size,shuffle=True)
    test_loader = DataLoader(test_data,args.batch_size,shuffle=True)

    # Model
    # NOTE(review): if args.model_type matches none of the branches below,
    # `model` is never assigned and the later print(model) raises NameError.
    model_out = 'error' if args.loss == 'E' else 'pred'
    if args.model_type == 'PredNet':
        model = PredNet(args.in_channels,args.stack_sizes,args.R_stack_sizes,
                        args.A_kernel_sizes,args.Ahat_kernel_sizes,
                        args.R_kernel_sizes,args.use_satlu,args.pixel_max,
                        args.Ahat_act,args.satlu_act,args.error_act,
                        args.LSTM_act,args.LSTM_c_act,args.bias,
                        args.use_1x1_out,args.FC,args.dropout_p,
                        args.send_acts,args.no_ER,args.RAhat,args.no_A_conv,
                        args.higher_satlu,args.local_grad,args.conv_dilation,
                        args.use_BN,model_out,device)
    elif args.model_type == 'MultiConvLSTM':
        model = MultiConvLSTM(args.in_channels,args.R_stack_sizes,
                              args.R_kernel_sizes,args.use_satlu,args.pixel_max,
                              args.Ahat_act,args.satlu_act,args.error_act,
                              args.LSTM_act,args.LSTM_c_act,args.bias,
                              args.use_1x1_out,args.FC,args.local_grad,
                              model_out,device)
    elif args.model_type == 'ConvLSTM':
        model = ConvLSTM(args.in_channels,args.hidden_channels,args.kernel_size,
                         args.LSTM_act,args.LSTM_c_act,args.out_act,
                         args.bias,args.FC,device)
    elif args.model_type == 'LadderNet':
        model = LadderNet(args.in_channels,args.stack_sizes,args.R_stack_sizes,
                          args.A_kernel_sizes,args.Ahat_kernel_sizes,
                          args.R_kernel_sizes,args.conv_dilation,args.use_BN,
                          args.use_satlu,args.pixel_max,args.A_act,
                          args.Ahat_act,args.satlu_act,args.error_act,
                          args.LSTM_act,args.LSTM_c_act,args.bias,
                          args.use_1x1_out,args.FC,args.no_R0,args.no_skip0,
                          args.no_A_conv,args.higher_satlu,args.local_grad,
                          model_out,device)
    elif args.model_type == 'StackedConvLSTM':
        model = StackedConvLSTM(args.in_channels,args.R_stack_sizes,
                                args.R_kernel_sizes,args.use_1x1_out,
                                args.FC,args.local_grad,args.forward_conv,
                                model_out,device)
    print(model)
    if args.load_weights_from is not None:
        model.load_state_dict(torch.load(args.load_weights_from))
    model.to(device)
    model.train()

    # Select loss function
    loss_fn = get_loss_fn(args.loss,args.layer_lambdas)
    loss_fn = loss_fn.to(device)

    # Optimizer
    # NOTE(review): lrs_step_size is 0 when num_iters < lr_steps + 1, which
    # StepLR rejects — confirm callers always pass num_iters > lr_steps.
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.learning_rate,weight_decay=args.wd)
    lrs_step_size = args.num_iters // (args.lr_steps+1)
    scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=lrs_step_size,
                                          gamma=0.1)

    # Stats
    ave_time = 0.0
    loss_data = [] # records loss every args.record_loss_every iters
    corr_data = [] # records correlation every args.record_loss_every iters
    train_losses = [] # records mean training loss every checkpoint
    train_corrs = [] # records mean training  correlation every checkpoint
    val_losses = [] # records mean validation loss every checkpoint
    val_corrs = [] # records mean validation correlation every checkpoint
    test_losses = [] # records mean test loss every checkpoint
    test_corrs = [] # records mean test correlation every checkpoint
    best_val_loss = float("inf") # will only save best weights
    if args.record_E:
        # Per-layer E (error) histories, keyed 'layer0', 'layer1', ...
        E_data = {'layer%d' % i:[] for i in range(model.nb_layers)}
        train_Es = {'layer%d' % i:[] for i in range(model.nb_layers)}
        val_Es = {'layer%d' % i:[] for i in range(model.nb_layers)}
        test_Es = {'layer%d' % i:[] for i in range(model.nb_layers)}
    # Training loop
    # NOTE(review): `iter` shadows the builtin; harmless here but worth
    # renaming in a behavior-changing pass.
    iter = 0
    epoch_count = 0
    while iter < args.num_iters:
        epoch_count += 1
        for X in train_loader:
            iter += 1
            optimizer.zero_grad()
            # Forward
            start_t = time.time()
            X = X.to(device)
            output = model(X)
            # Compute loss
            if args.loss == 'E':
                loss = loss_fn(output)
            else:
                # Model predicts frames 1..T; drop t=0 from the target
                X_no_t0 = X[:,1:,:,:,:]
                loss = loss_fn(output,X_no_t0)
            # Backward pass
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Record loss
            iter_time = time.time() - start_t
            ave_time = (ave_time*(iter-1) + iter_time)/iter
            if iter % args.record_loss_every == 0:
                loss_datapoint = loss.data.item()
                # NOTE(review): scheduler.get_lr() is deprecated in newer
                # torch in favor of get_last_lr() — confirm torch version.
                print('Epoch:', epoch_count,
                      'Iter:', iter,
                      'Loss:', loss_datapoint,
                      'lr:', scheduler.get_lr(),
                      'ave time: ', ave_time)
                loss_data.append(loss_datapoint)
                if args.record_E:
                    E_means = torch.mean(output.detach(),dim=0)
                    for l in range(model.nb_layers):
                        E_datapoint = E_means[l].data.item()
                        E_data['layer%d' % l].append(E_datapoint)
                if args.record_corr:
                    # Temporarily switch the model to 'pred' output to
                    # measure prediction/target correlation, then restore.
                    model_output = model.output
                    model.output = 'pred'
                    output = model(X)
                    X_no_t0 = X[:,1:,:,:,:]
                    corr = correlation(output,X_no_t0)
                    corr_data.append(corr.data.item())
                    model.output = model_output
            if iter >= args.num_iters:
                break
        # Checkpoint
        last_epoch = (iter >= args.num_iters)
        if epoch_count % args.checkpoint_every == 0 or last_epoch:
            # Train
            print("Checking training loss...")
            train_checkpoint = checkpoint(train_loader,model,device,args)
            if args.record_E:
                train_loss,train_corr,train_E = train_checkpoint
                for l in range(model.nb_layers):
                    train_Es['layer%d' % l].append(train_E[l])
            else:
                train_loss,train_corr = train_checkpoint
            print("Training loss is ", train_loss)
            print("Training average correlation is ", train_corr)
            train_losses.append(train_loss)
            train_corrs.append(train_corr)
            # Validation
            print("Checking validation loss...")
            val_checkpoint = checkpoint(val_loader,model,device,args)
            if args.record_E:
                val_loss,val_corr,val_E = val_checkpoint
                for l in range(model.nb_layers):
                    val_Es['layer%d' % l].append(val_E[l])
            else:
                val_loss,val_corr = val_checkpoint
            print("Validation loss is ", val_loss)
            print("Validation average correlation is ",val_corr)
            val_losses.append(val_loss)
            val_corrs.append(val_corr)
            # Test
            print("Checking test loss...")
            test_checkpoint = checkpoint(test_loader,model,device,args)
            if args.record_E:
                test_loss,test_corr,test_E = test_checkpoint
                for l in range(model.nb_layers):
                    test_Es['layer%d' % l].append(test_E[l])
            else:
                test_loss,test_corr = test_checkpoint
            print("Test loss is ", test_loss)
            print("Test average correlation is ", test_corr)
            test_losses.append(test_loss)
            test_corrs.append(test_corr)
            # Write stats file
            if not os.path.isdir(args.results_dir):
                os.mkdir(args.results_dir)
            stats = {'loss_data':loss_data,
                     'corr_data':corr_data,
                     'train_mse_losses':train_losses,
                     'train_corrs':train_corrs,
                     'val_mse_losses':val_losses,
                     'val_corrs':val_corrs,
                     'test_mse_losses':test_losses,
                     'test_corrs':test_corrs}
            if args.record_E:
                stats['E_data'] = E_data
                stats['train_Es'] = train_Es
                stats['val_Es'] = val_Es
                stats['test_Es'] = test_Es
            results_file_name = '%s/%s' % (args.results_dir,args.out_data_file)
            with open(results_file_name, 'w') as f:
                json.dump(stats, f)
            # Save model weights
            if val_loss < best_val_loss: # use val (not test) to decide to save
                best_val_loss = val_loss
                if args.checkpoint_path is not None:
                    torch.save(model.state_dict(),
                               args.checkpoint_path)
Exemplo n.º 8
0
def main(args):
    """Compute and save per-layer representational similarity matrices.

    Loads a trained model in 'rep' mode, averages spatially-aggregated layer
    representations per label over the test set (plus the raw last-frame
    pixels as a baseline "layer"), builds a similarity matrix per layer,
    sorts each matrix by category, and dumps everything with hickle.

    Relies on module-level helpers (aggregate_space, get_similarity_matrix,
    sort_similarity_matrix, Partitioner, CCN) defined elsewhere in the file.
    """
    # CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Load model
    model_out = 'rep'  # Always rep to get representations
    if args.model_type == 'PredNet':
        model = PredNet(args.in_channels, args.stack_sizes, args.R_stack_sizes,
                        args.A_kernel_sizes, args.Ahat_kernel_sizes,
                        args.R_kernel_sizes, args.use_satlu, args.pixel_max,
                        args.Ahat_act, args.satlu_act, args.error_act,
                        args.LSTM_act, args.LSTM_c_act, args.bias,
                        args.use_1x1_out, args.FC, args.dropout_p,
                        args.send_acts, args.no_ER, args.RAhat, args.no_A_conv,
                        args.higher_satlu, args.local_grad, args.conv_dilation,
                        args.use_BN, model_out, device)
    elif args.model_type == 'MultiConvLSTM':
        model = MultiConvLSTM(args.in_channels, args.R_stack_sizes,
                              args.R_kernel_sizes, args.use_satlu,
                              args.pixel_max, args.Ahat_act, args.satlu_act,
                              args.error_act, args.LSTM_act, args.LSTM_c_act,
                              args.bias, args.use_1x1_out, args.FC,
                              args.local_grad, model_out, device)
    elif args.model_type == 'ConvLSTM':
        model = ConvLSTM(args.in_channels, args.hidden_channels,
                         args.kernel_size, args.LSTM_act, args.LSTM_c_act,
                         args.out_act, args.bias, args.FC, device)
    elif args.model_type == 'LadderNet':
        model = LadderNet(
            args.in_channels, args.stack_sizes, args.R_stack_sizes,
            args.A_kernel_sizes, args.Ahat_kernel_sizes, args.R_kernel_sizes,
            args.conv_dilation, args.use_BN, args.use_satlu, args.pixel_max,
            args.A_act, args.Ahat_act, args.satlu_act, args.error_act,
            args.LSTM_act, args.LSTM_c_act, args.bias, args.use_1x1_out,
            args.FC, args.no_R0, args.no_skip0, args.no_A_conv,
            args.higher_satlu, args.local_grad, model_out, device)
    elif args.model_type == 'StackedConvLSTM':
        model = StackedConvLSTM(args.in_channels, args.R_stack_sizes,
                                args.R_kernel_sizes, args.use_1x1_out, args.FC,
                                args.local_grad, args.forward_conv, model_out,
                                device)
    if args.load_weights_from is not None:
        model.load_state_dict(torch.load(args.load_weights_from))
    model.to(device)
    model.eval()
    # LadderNet built with no_R0 exposes one fewer representation layer
    if args.model_type == 'LadderNet' and args.no_R0:
        nb_reps = model.nb_layers - 1
    else:
        nb_reps = model.nb_layers

    # Dataset
    downsample_size = (args.downsample_size, args.downsample_size)
    test_data = CCN(args.test_data_path,
                    args.seq_len,
                    downsample_size=downsample_size,
                    return_labels=True,
                    last_only=args.last_only)
    partitioner = Partitioner(test_data, args.idx_dict_hkl)
    labels = sorted(partitioner.labels)
    n_labels = len(labels)
    print("There are %d labels in the dataset" % n_labels)

    with torch.no_grad():
        # Get list of layer representations for each label
        label_reps = []
        for label_i, label in enumerate(labels):
            # Get data partition for current label
            print("Starting label %d/%d: %s" % (label_i + 1, n_labels, label))
            partition = partitioner.get_partition(label)
            n_samples = len(partition)
            dataloader = DataLoader(partition, args.batch_size)
            # Run model, keeping running sum of representations.
            # (The list-of-lists init is replaced by the first batch's sums.)
            layer_reps = [[] for l in range(nb_reps + 1)]  # nb_reps + pixels
            for batch_i, batch in enumerate(dataloader):
                X = batch[0].to(device)
                # Get representations
                reps = model(X)  # list of reps, one for each layer
                pixels = X[:,
                           -1, :, :, :]  # Use last image to compare to RGB reps
                # Aggregate across space
                agg_pixels = aggregate_space(pixels, args.aggregate_method)
                agg_reps = [agg_pixels]  # first layer is pixels
                for l in range(nb_reps):
                    agg_rep = aggregate_space(reps[l], args.aggregate_method)
                    agg_reps.append(agg_rep)
                # Sum batch
                layer_sums = []
                for l in range(nb_reps + 1):
                    layer_sum = torch.sum(agg_reps[l], dim=0)
                    layer_sum = layer_sum.unsqueeze(0)  # Add label dimension
                    layer_sums.append(layer_sum)
                # Update running sums
                if batch_i == 0:
                    layer_reps = layer_sums
                else:
                    for l in range(nb_reps + 1):
                        layer_reps[l] += layer_sums[l]
            # Divide by n_samples to get average
            layer_reps = [layer_rep / n_samples for layer_rep in layer_reps]
            label_reps.append(layer_reps)
            print("Finished processing samples for label: %s" % label)
        layer_lists = list(map(list, zip(*label_reps)))  # transpose lists
        layer_tensors = []
        for l in range(nb_reps + 1):
            layer_tensor = torch.cat(layer_lists[l], dim=0)
            layer_tensors.append(layer_tensor)

    # Set up data for saving similarity matrices
    info = {
        'aggregate_method': args.aggregate_method,
        'similarity_measure': args.similarity_measure
    }
    RSA_data = {'info': info}
    if args.cat_dict_json is None:
        # Default: category is the prefix of the label before the first '_'
        cats = set([lbl.split('_')[0] for lbl in labels])
        cat_dict = {cat: cat for cat in cats}
    else:
        with open(args.cat_dict_json, 'r') as f:
            cat_dict = json.load(f)
    print("Computing and sorting similarity matrices")
    for l, layer_tensor in enumerate(layer_tensors):
        # Get similarity matrix
        S = get_similarity_matrix(layer_tensor, args.similarity_measure)
        S = S.cpu().numpy()
        # Sort similarity matrix
        sorted_S, sorted_labels = sort_similarity_matrix(S, cat_dict, labels)
        # Save matrices
        layer_name = 'layer%d' % (l - 1) if l > 0 else 'pixels'
        RSA_data[layer_name] = sorted_S
    RSA_data['labels'] = sorted_labels  # all sorted labels should be the same

    # Save similarity matrices (local renamed from `dir`, which shadowed the
    # builtin)
    results_dir = args.results_dir
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_path = os.path.join(args.results_dir, args.out_data_file)
    print("Saving results to %s" % results_path)
    hkl.dump(RSA_data, results_path)
Exemplo n.º 9
0
# Hyperparameters and experiment configuration.
batch_size = 5
sequence_length = 8
epoch = 20
image_shape = 64            # side length of the pose heatmap grid
ConvLSTM_channel = 128
pose_generator_basefeature = 128
model_file = 'model_it_1943'
train_player_list = [[5, 6], [7, 8], [9, 10], [11, 12]]
test_player_list = [[13, 14]]
sigma = 0.01                # Gaussian spread for pose heatmaps — TODO confirm
test_interval = 1
eval_interval = 1
#model init
pose_generator = Generator(pose_generator_basefeature)
pose_discriminator = Discriminator(ConvLSTM_channel, image_shape)
convlstm_model = ConvLSTM(ConvLSTM_channel, sequence_length)

#data loader init
# NOTE(review): `player_list` is not defined in this snippet — presumably
# this should be `train_player_list`; verify against the rest of the file.
train_pose_dataset = Pose_Dataset(Boxing_dir, sequence_length, player_list, len(player_list),sigma,\
      epoch, grid_point = image_shape)
test_pose_dataset = Pose_Dataset(Boxing_dir, sequence_length, test_player_list, len(test_player_list),sigma,\
      epoch, grid_point = image_shape, mode = 'test')
train_loader = torch.utils.data.DataLoader(dataset=train_pose_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=4)
# NOTE(review): this overwrites `test_pose_dataset` with a DataLoader built
# over the TRAIN dataset — looks like it should wrap `test_pose_dataset`
# (the Pose_Dataset created above) and be named `test_loader`; confirm.
test_pose_dataset = torch.utils.data.DataLoader(dataset=train_pose_dataset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=4)
G_optmizer = torch.optim.Adam(pose_generator.parameters(),
Exemplo n.º 10
0
def main(args):
    """Render ground-truth frames and model predictions to PNG files.

    Picks ``args.num_seqs`` random sequences from the test set, runs the
    selected model on each, and saves every input frame plus every predicted
    frame (predictions are offset by one timestep) into ``args.results_dir``.

    Args:
        args: argparse.Namespace carrying dataset paths, model type and
            hyperparameters, checkpoint path, and output options.

    Raises:
        ValueError: if ``args.dataset`` or ``args.model_type`` is not one of
            the supported values.
    """
    # CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Data: Don't shuffle to keep indexes consistent
    if args.dataset == 'KITTI':
        test_data = KITTI(args.test_data_path, args.test_sources_path,
                          args.seq_len)
    elif args.dataset == 'CCN':
        downsample_size = (args.downsample_size, args.downsample_size)
        test_data = CCN(args.test_data_path,
                        args.seq_len,
                        downsample_size=downsample_size,
                        last_only=args.last_only)
    else:
        # Fail loudly instead of hitting a NameError on test_data below.
        raise ValueError("Unknown dataset: %s" % args.dataset)

    # Load model
    model_out = 'pred'  # Always pred to get predicted images
    if args.model_type == 'PredNet':
        model = PredNet(args.in_channels, args.stack_sizes, args.R_stack_sizes,
                        args.A_kernel_sizes, args.Ahat_kernel_sizes,
                        args.R_kernel_sizes, args.use_satlu, args.pixel_max,
                        args.Ahat_act, args.satlu_act, args.error_act,
                        args.LSTM_act, args.LSTM_c_act, args.bias,
                        args.use_1x1_out, args.FC, args.dropout_p,
                        args.send_acts, args.no_ER, args.RAhat, args.no_A_conv,
                        args.higher_satlu, args.local_grad, args.conv_dilation,
                        args.use_BN, model_out, device)
    elif args.model_type == 'MultiConvLSTM':
        model = MultiConvLSTM(args.in_channels, args.R_stack_sizes,
                              args.R_kernel_sizes, args.use_satlu,
                              args.pixel_max, args.Ahat_act, args.satlu_act,
                              args.error_act, args.LSTM_act, args.LSTM_c_act,
                              args.bias, args.use_1x1_out, args.FC,
                              args.local_grad, model_out, device)
    elif args.model_type == 'ConvLSTM':
        model = ConvLSTM(args.in_channels, args.hidden_channels,
                         args.kernel_size, args.LSTM_act, args.LSTM_c_act,
                         args.out_act, args.bias, args.FC, device)
    elif args.model_type == 'LadderNet':
        model = LadderNet(
            args.in_channels, args.stack_sizes, args.R_stack_sizes,
            args.A_kernel_sizes, args.Ahat_kernel_sizes, args.R_kernel_sizes,
            args.conv_dilation, args.use_BN, args.use_satlu, args.pixel_max,
            args.A_act, args.Ahat_act, args.satlu_act, args.error_act,
            args.LSTM_act, args.LSTM_c_act, args.bias, args.use_1x1_out,
            args.FC, args.no_R0, args.no_skip0, args.no_A_conv,
            args.higher_satlu, args.local_grad, model_out, device)
    elif args.model_type == 'StackedConvLSTM':
        model = StackedConvLSTM(args.in_channels, args.R_stack_sizes,
                                args.R_kernel_sizes, args.use_1x1_out, args.FC,
                                args.local_grad, args.forward_conv, model_out,
                                device)
    else:
        raise ValueError("Unknown model_type: %s" % args.model_type)

    if args.load_weights_from is not None:
        model.load_state_dict(torch.load(args.load_weights_from))
    model.to(device)

    # Get random indices of sequences to save
    total_seqs = len(test_data)
    seq_ids = np.random.choice(np.arange(total_seqs),
                               size=args.num_seqs,
                               replace=False)

    # Avoid shadowing the builtin `dir`; makedirs(exist_ok=True) is race-free
    # unlike the isdir/mkdir check.
    results_dir = args.results_dir
    os.makedirs(results_dir, exist_ok=True)

    # Get predicted images
    model.eval()
    with torch.no_grad():
        for num, i in enumerate(seq_ids):
            if args.sanity_check:  # Get first part of seq i, second part of i+1
                X_i = test_data[i]
                # Wrap around the actual number of sampled ids.
                next_i = seq_ids[(num + 1) % len(seq_ids)]
                X_ip1 = test_data[next_i]
                halfway = args.seq_len // 2
                X = torch.cat((X_i[:halfway], X_ip1[halfway:]), dim=0)
                X = X.to(device)
            else:
                X = test_data[i].to(device)
            X = X.unsqueeze(0)  # Add batch dim
            seq_len = X.shape[1]
            preds = model(X)
            preds = preds.squeeze(0).permute(0, 2, 3, 1)  # (len,H,W,channels)
            preds = preds.cpu().numpy()
            X = X.squeeze(0).permute(0, 2, 3, 1)  # (len,H,W,channels)
            X = X.cpu().numpy()
            if test_data.norm:
                # Undo [0,1] normalization before writing 8-bit images.
                X = np.round(X * 255.)
                preds = np.round(preds * 255.)
            for t in range(seq_len):
                X_t = np.uint8(X[t])
                X_img = Image.fromarray(X_t)
                fn = args.out_data_file
                X_img_path = '%s/%s_X%d_t%d.png' % (results_dir, fn, i, t)
                print("Saving image at %s" % X_img_path)
                X_img.save(X_img_path)
                if t < seq_len - 1:  # 1 less prediction
                    preds_t = np.uint8(preds[t])
                    pred_img = Image.fromarray(preds_t)
                    pred_img_path = '%s/%s_pred%d_t%d.png' % (results_dir, fn,
                                                              i, t + 1)
                    print("Saving image at %s" % pred_img_path)
                    pred_img.save(pred_img_path)
        print("Done")