Beispiel #1
0
    def __init__(self,
                 embedding_dim,
                 input_dims,
                 hidden_dim,
                 num_slots,
                 encoder='cswm',
                 cnn_size='small',
                 decoder='broadcast',
                 trans_model='gnn',
                 identity_action=False,
                 residual=False,
                 canonical=False):
        """Build the encoder, transition model and decoder sub-modules.

        Args:
            embedding_dim: Size of one slot embedding; forwarded to the
                encoder, the transition model and the decoder.
            input_dims: Indexable shape of one observation. ``input_dims[1:]``
                is handed to the broadcast decoder as the image size, and a
                copy with element 0 incremented by one is handed to the CNN
                decoder as its output size.
            hidden_dim: Stored as ``self.hidden_dims``.
                NOTE(review): it is not forwarded to any sub-module here; the
                GNN transition model uses a hard-coded hidden size of 512.
            num_slots: Number of object slots; forwarded as ``num_objects``.
            encoder: Encoder type. Only ``'cswm'`` is supported.
            cnn_size: Forwarded to ``modules.EncoderCSWM``.
            decoder: Decoder type, ``'broadcast'`` or ``'cnn'``.
            trans_model: Transition model type, ``'gnn'`` or ``'attention'``.
            identity_action: Stored as ``self.identity_action_flag``.
            residual: Forwarded to ``modules.TransitionGNN``.
            canonical: Stored as ``self.canonical``.

        Raises:
            ValueError: If ``encoder``, ``trans_model`` or ``decoder`` names
                an unsupported component.
        """
        super(NodModel, self).__init__()
        self.embedding_dim = embedding_dim
        self.input_dims = input_dims
        # The attribute is plural although the parameter is singular; the
        # name is kept because code outside this method may read it.
        self.hidden_dims = hidden_dim

        self.num_slots = num_slots
        self.identity_action_flag = identity_action
        self.canonical = canonical

        if encoder == 'cswm':
            self.encoder = modules.EncoderCSWM(
                input_dims=self.input_dims,
                embedding_dim=self.embedding_dim,
                num_objects=self.num_slots,
                cnn_size=cnn_size)
        else:
            # Fail here with a clear message instead of an AttributeError at
            # the parameter-count prints below.
            raise ValueError(
                f"Unsupported encoder {encoder!r}; expected 'cswm'.")

        # NOTE(review): the action dimensionality (12) is hard-coded in both
        # transition models below.
        if trans_model == 'gnn':
            self.transition_model = modules.TransitionGNN(
                input_dim=self.embedding_dim,
                hidden_dim=512,
                action_dim=12,
                num_objects=self.num_slots,
                residual=residual)
        elif trans_model == 'attention':
            self.transition_model = attention.MultiHeadCondAttention(
                n_head=5,
                input_feature_dim=self.embedding_dim + 12,
                out_dim=self.embedding_dim,
                dim_k=128,
                dim_v=128)
        else:
            raise ValueError(
                f"Unsupported trans_model {trans_model!r}; "
                "expected 'gnn' or 'attention'.")

        if decoder == 'broadcast':
            self.decoder = spd.BroadcastDecoder(
                latent_dim=self.embedding_dim,
                output_dim=4,  # 3 rgb channels and one mask
                hidden_channels=32,
                num_layers=4,
                img_dims=self.input_dims[1:],  # width and height of square image
                act_fn='elu')
        elif decoder == 'cnn':
            # Work on a copy: `input_dims` is shared with the caller and with
            # `self.input_dims`, and may be immutable (tuple / torch.Size).
            out_shape = list(self.input_dims)
            # Presumably the extra entry is the mask channel, mirroring
            # output_dim=4 in the broadcast branch -- TODO confirm.
            out_shape[0] += 1
            self.decoder = modules.DecoderCNNMedium(
                input_dim=self.embedding_dim,
                hidden_dim=32,
                num_objects=self.num_slots,
                output_size=out_shape)
        else:
            raise ValueError(
                f"Unsupported decoder {decoder!r}; "
                "expected 'broadcast' or 'cnn'.")

        print('Number of params in encoder ', util.count_params(self.encoder))
        print('Number of params in transition model ',
              util.count_params(self.transition_model))
        print('Number of params in decoder ', util.count_params(self.decoder))

        self.l2_loss = nn.MSELoss(reduction="mean")
Beispiel #2
0
                               copy_action=args.copy_action,
                               encoder=args.encoder).to(device)

# Initialise the model weights and create its optimizer.
model.apply(utils.weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.decoder:
    # The decoder size is chosen to match the encoder size.
    if args.encoder == 'large':
        decoder_cls = modules.DecoderCNNLarge
    elif args.encoder == 'medium':
        decoder_cls = modules.DecoderCNNMedium
    elif args.encoder == 'small':
        decoder_cls = modules.DecoderCNNSmall
    else:
        # Without this branch `decoder` would be unbound and the next
        # statement would fail with a NameError.
        raise ValueError(
            "No decoder available for encoder {!r}; expected 'large', "
            "'medium' or 'small'.".format(args.encoder))

    decoder = decoder_cls(input_dim=args.embedding_dim,
                          num_objects=args.num_objects,
                          hidden_dim=args.hidden_dim // 16,
                          output_size=input_shape).to(device)
    decoder.apply(utils.weights_init)
    # The decoder is trained by its own optimizer, separate from the model's.
    optimizer_dec = torch.optim.Adam(decoder.parameters(),
                                     lr=args.learning_rate)

# Train model.
print('Starting model training...')
step = 0
best_loss = 1e9
Beispiel #3
0
def _build_decoder(args, input_shape, device):
    """Create and initialise the CNN decoder matching ``args.encoder``."""
    if args.encoder == 'large':
        decoder_cls = modules.DecoderCNNLarge
    elif args.encoder == 'medium':
        decoder_cls = modules.DecoderCNNMedium
    elif args.encoder == 'small':
        decoder_cls = modules.DecoderCNNSmall
    else:
        # Previously this case left `decoder` unbound and crashed later with
        # an UnboundLocalError.
        raise ValueError(
            "No decoder available for encoder {!r}; expected 'large', "
            "'medium' or 'small'.".format(args.encoder))

    decoder = decoder_cls(input_dim=args.embedding_dim,
                          num_objects=args.num_objects,
                          hidden_dim=args.hidden_dim // 16,
                          output_size=input_shape).to(device)
    decoder.apply(utils.weights_init)
    return decoder


def _reconstruction_loss(model, decoder, data_batch):
    """Return the summed BCE reconstruction loss for a batch and its successor."""
    obs, action, next_obs = data_batch
    batch_size = obs.size(0)

    objs = model.obj_extractor(obs)
    state = model.obj_encoder(objs)

    # Reconstruct the current observation from its encoded state.
    rec = torch.sigmoid(decoder(state))
    loss = F.binary_cross_entropy(rec, obs, reduction='sum') / batch_size

    # Reconstruct the next observation from the predicted next state; the
    # transition model's output is added to the current state.
    next_state_pred = state + model.transition_model(state, action)
    next_rec = torch.sigmoid(decoder(next_state_pred))
    next_loss = F.binary_cross_entropy(
        next_rec, next_obs, reduction='sum') / batch_size

    return loss + next_loss


def train_c_swm(args):
    """Train a ``ContrastiveSWM`` model and checkpoint the best epoch.

    Args:
        args: Namespace-like object. This function reads ``no_cuda``,
            ``name``, ``seed``, ``save_folder``, ``dataset``, ``batch_size``,
            ``num_workers``, ``embedding_dim``, ``hidden_dim``,
            ``action_dim``, ``num_objects``, ``sigma``, ``hinge``,
            ``ignore_action``, ``copy_action``, ``encoder``, ``decoder``,
            ``learning_rate``, ``epochs`` and ``log_interval``, and sets
            ``args.cuda``.

    Returns:
        The model as it stands after the final epoch. The weights written to
        ``model.pt`` are those of the epoch with the lowest average loss,
        which may be an earlier epoch.

    Raises:
        ValueError: If ``args.decoder`` is set and ``args.encoder`` is not
            one of ``'large'``, ``'medium'`` or ``'small'``.
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print("Inside train_c_swm")

    # Runs without an explicit name are identified by their start time.
    if args.name == 'none':
        exp_name = datetime.datetime.now().isoformat()
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    save_folder = '{}/{}/'.format(args.save_folder, exp_name)
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(save_folder, exist_ok=True)
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    # NOTE: progress messages below use print(), so they reach stdout only;
    # the file handler receives records emitted through `logging`.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))

    # Close the metadata file deterministically instead of leaking the handle.
    with open(meta_file, "wb") as meta_fh:
        pickle.dump({'args': args}, meta_fh)

    device = torch.device('cuda' if args.cuda else 'cpu')
    print("About to get dataset")
    dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
    train_loader = data.DataLoader(dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    print("Dataset loaded")

    # Peek at one batch to learn the shape of a single observation.
    # next(iter(...)) is the Python 3 iterator protocol; the former
    # `.next()` method is not available on current DataLoader iterators.
    sample_obs = next(iter(train_loader))[0]
    input_shape = sample_obs[0].size()

    model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim,
                                   hidden_dim=args.hidden_dim,
                                   action_dim=args.action_dim,
                                   input_dims=input_shape,
                                   num_objects=args.num_objects,
                                   sigma=args.sigma,
                                   hinge=args.hinge,
                                   ignore_action=args.ignore_action,
                                   copy_action=args.copy_action,
                                   encoder=args.encoder).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    decoder = None
    optimizer_dec = None
    if args.decoder:
        decoder = _build_decoder(args, input_shape, device)
        optimizer_dec = torch.optim.Adam(decoder.parameters(),
                                         lr=args.learning_rate)

    # Train model.
    print('Starting model training...')
    best_loss = 1e9

    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if args.decoder:
                optimizer_dec.zero_grad()
                loss = _reconstruction_loss(model, decoder, data_batch)
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if args.decoder:
                optimizer_dec.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

        avg_loss = train_loss / len(train_loader.dataset)

        # Keep only the best epoch on disk.
        # NOTE(review): only the model's weights are saved; a decoder trained
        # alongside it is not checkpointed.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

    return model