import datetime
import logging
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils import data

import modules
import utils


def train_c_swm(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print("Inside train_c_swm")

    now = datetime.datetime.now()
    timestamp = now.isoformat()

    if args.name == 'none':
        exp_name = timestamp
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    exp_counter = 0
    save_folder = '{}/{}/'.format(args.save_folder, exp_name)

    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))
    # print = logger.info

    pickle.dump({'args': args}, open(meta_file, "wb"))

    device = torch.device('cuda' if args.cuda else 'cpu')

    print("About to get dataset")
    dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
    train_loader = data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)
    print("Dataset loaded")

    # Get one sample batch to infer the observation shape.
    obs = next(iter(train_loader))[0]
    input_shape = obs[0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        encoder=args.encoder).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Optional pixel-space decoder, one variant per encoder size.
    if args.decoder:
        if args.encoder == 'large':
            decoder = modules.DecoderCNNLarge(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'medium':
            decoder = modules.DecoderCNNMedium(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'small':
            decoder = modules.DecoderCNNSmall(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        decoder.apply(utils.weights_init)
        optimizer_dec = torch.optim.Adam(
            decoder.parameters(), lr=args.learning_rate)

    # Train model.
    print('Starting model training...')
    step = 0
    best_loss = 1e9

    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if args.decoder:
                # Pixel-reconstruction objective: decode the current state and
                # the predicted next state, score both against the observations.
                optimizer_dec.zero_grad()
                obs, action, next_obs = data_batch
                objs = model.obj_extractor(obs)
                state = model.obj_encoder(objs)

                rec = torch.sigmoid(decoder(state))
                loss = F.binary_cross_entropy(
                    rec, obs, reduction='sum') / obs.size(0)

                next_state_pred = state + model.transition_model(state, action)
                next_rec = torch.sigmoid(decoder(next_state_pred))
                next_loss = F.binary_cross_entropy(
                    next_rec, next_obs, reduction='sum') / obs.size(0)
                loss += next_loss
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if args.decoder:
                optimizer_dec.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

            step += 1

        avg_loss = train_loss / len(train_loader.dataset)
        # print('====> Epoch: {} Average loss: {:.6f}'.format(
        #     epoch, avg_loss))

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

    return model
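

# Example invocation (a minimal sketch, not part of the original script): builds a
# plain namespace carrying every attribute that train_c_swm reads. The attribute
# names match the args accessed in the function above; the values and the dataset
# path are illustrative assumptions and should be matched to this repository's
# actual argument parser before use.
if __name__ == '__main__':
    from argparse import Namespace

    example_args = Namespace(
        no_cuda=False, name='none', seed=42, save_folder='checkpoints',
        dataset='data/shapes_train.h5',  # assumed path to an HDF5 transitions file
        batch_size=1024, num_workers=4, embedding_dim=2, hidden_dim=512,
        action_dim=4, num_objects=5, sigma=0.5, hinge=1.0,
        ignore_action=False, copy_action=False, encoder='small',
        decoder=False, learning_rate=5e-4, epochs=100, log_interval=20)

    trained_model = train_c_swm(example_args)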