def runner(args):
    """Train and validate an FC-LSTM on the Moving-MNIST sequence file.

    Expects ``args`` to carry: seed, is_cuda, data_path, is_quickrun,
    batch_size, epoch, model, optimizer, learning_rate,
    mode ('pred' | 'recon' | 'both'), recon_loss_lambda, is_save,
    save_path, model_save_file.  Per-epoch train/validation losses are
    written through the project ``Log`` helper; when ``args.is_save`` is
    set, the model with the best validation reconstruction MSE is saved.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.is_cuda:
        print('CUDA is set to DO NOT USE')
        device = torch.device("cpu")

    # Moving-MNIST file: axis 0 is time, axis 1 indexes sequences.
    # NOTE(review): assumes shape (20, N, 64, 64) — 10 input + 10 target
    # frames, as the slicing in input_target_maker implies; confirm.
    data_set = np.load(args.data_path + 'mnist_test_seq.npy')
    train_set = data_set[:, :9000]
    valid_set = data_set[:, 9000:]
    del data_set  # drop the name; train/valid slices keep what is needed

    if args.is_quickrun:
        # Tiny subset and short schedule for smoke-testing the pipeline.
        train_set = train_set[:, :15]
        valid_set = valid_set[:, :10]
        args.batch_size = 5
        args.epoch = 2

    def input_target_maker(batch, device):
        """Scale frames to [0, 1] and build the three tensors: input
        frames, reversed-input reconstruction target, and future-frame
        prediction target (all moved to ``device``)."""
        batch = batch / 255.
        input_x = batch[:10, :, :, :]
        pred_target = batch[10:, :, :, :]
        # np.flip returns a negative-stride view; make it contiguous
        # before handing it to torch.Tensor.
        rec_target = np.flip(batch[:10, :, :, :], axis=0)
        rec_target = np.ascontiguousarray(rec_target)
        input_x = torch.Tensor(input_x).to(device)
        pred_target = torch.Tensor(pred_target).to(device)
        rec_target = torch.Tensor(rec_target).to(device)
        return input_x, rec_target, pred_target

    def accumulate_losses(out1, out2, rec_t, pred_t,
                          mse_recon, bce_recon, mse_pred, bce_pred):
        """Compute per-mode losses on one batch, append their scalar
        values to the running lists, and return the (mse, bce) loss
        tensors used for backprop.  In 'both' mode the returned tensors
        are the recon_loss_lambda-weighted recon + pred sums."""
        if args.mode == 'pred':
            mse = F.mse_loss(out1, pred_t)
            bce = F.binary_cross_entropy(out1, pred_t)
            mse_recon.append(0)
            bce_recon.append(0)
            mse_pred.append(mse.item())
            bce_pred.append(bce.item())
        elif args.mode == 'recon':
            mse = F.mse_loss(out1, rec_t)
            bce = F.binary_cross_entropy(out1, rec_t)
            mse_recon.append(mse.item())
            bce_recon.append(bce.item())
            mse_pred.append(0)
            bce_pred.append(0)
        else:  # 'both'
            mse_r = F.mse_loss(out1, rec_t)
            bce_r = F.binary_cross_entropy(out1, rec_t)
            mse_p = F.mse_loss(out2, pred_t)
            bce_p = F.binary_cross_entropy(out2, pred_t)
            mse_recon.append(mse_r.item())
            bce_recon.append(bce_r.item())
            mse_pred.append(mse_p.item())
            bce_pred.append(bce_p.item())
            mse = args.recon_loss_lambda * mse_r + mse_p
            bce = args.recon_loss_lambda * bce_r + bce_p
        return mse, bce

    if args.model == 'ED_R_01':
        from models import fc_lstm_v0001 as m
    elif args.model in ('b', 'c', 'd'):  # all three map to the same class
        from models.ED_lstmcell_v0001 import FC_LSTM as m
    else:
        raise Exception('wrong model')

    model = m(args).to(device)
    model.apply(weight_init)  # project-level init helper

    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate,
                                  momentum=0.9, weight_decay=0.0001)
    # Per-batch stepping: LR decays by 0.9 every 30 optimizer steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.9)

    log = Log(args)
    best_loss = float('inf')

    for e in range(args.epoch):
        batch_mse_recon_loss = []
        batch_bce_recon_loss = []
        batch_mse_pred_loss = []
        batch_bce_pred_loss = []
        idx = np.random.permutation(len(train_set[0]))  # i.i.d. sampling

        for i in range(len(train_set[0]) // args.batch_size):
            model.train()
            # BUG FIX: slice the permutation by whole batches. The old
            # idx[i:i + batch_size] produced overlapping windows drawn
            # from the first few permutation entries only.
            batch_idx = idx[i * args.batch_size:(i + 1) * args.batch_size]
            input_x, rec_target, pred_target = input_target_maker(
                train_set[:, batch_idx], device)

            optimizer.zero_grad()
            rec1, rec2 = model(input_x)
            _, bce_train_loss = accumulate_losses(
                rec1, rec2, rec_target, pred_target,
                batch_mse_recon_loss, batch_bce_recon_loss,
                batch_mse_pred_loss, batch_bce_pred_loss)

            loss = bce_train_loss  # this runner always optimizes BCE
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            # BUG FIX: optimizer.step() was missing — scheduler.step()
            # alone only changes the learning rate, never the weights.
            optimizer.step()
            scheduler.step()

        episode_train_mse_recon_loss = np.mean(batch_mse_recon_loss)
        episode_train_bce_recon_loss = np.mean(batch_bce_recon_loss)
        episode_train_mse_pred_loss = np.mean(batch_mse_pred_loss)
        episode_train_bce_pred_loss = np.mean(batch_bce_pred_loss)

        batch_mse_recon_loss = []
        batch_bce_recon_loss = []
        batch_mse_pred_loss = []
        batch_bce_pred_loss = []

        for i in range(len(valid_set[0]) // args.batch_size):
            with torch.no_grad():
                model.eval()
                # BUG FIX: step through validation in whole batches (the
                # old code sliced i:i + batch_size, re-reading the first
                # sequences every iteration).
                input_x, rec_target, pred_target = input_target_maker(
                    valid_set[:, i * args.batch_size:(i + 1) * args.batch_size],
                    device)
                rec1, rec2 = model(input_x)
                accumulate_losses(
                    rec1, rec2, rec_target, pred_target,
                    batch_mse_recon_loss, batch_bce_recon_loss,
                    batch_mse_pred_loss, batch_bce_pred_loss)

        episode_val_mse_recon_loss = np.mean(batch_mse_recon_loss)
        episode_val_bce_recon_loss = np.mean(batch_bce_recon_loss)
        episode_val_mse_pred_loss = np.mean(batch_mse_pred_loss)
        episode_val_bce_pred_loss = np.mean(batch_bce_pred_loss)

        # T: train, V: validation, M: mse, B: bce, R: recon, P: pred
        log_string = 'Epoch: {}, TMR: {}, TBR: {}, TMP: {}, TBP: {}, VMR: {}, VBR: {}, VMP: {}, VBP: {}'\
            .format(e,
                    episode_train_mse_recon_loss,
                    episode_train_bce_recon_loss,
                    episode_train_mse_pred_loss,
                    episode_train_bce_pred_loss,
                    episode_val_mse_recon_loss,
                    episode_val_bce_recon_loss,
                    episode_val_mse_pred_loss,
                    episode_val_bce_pred_loss)
        log.log(log_string)

        if args.is_save:
            # Model-selection criterion: validation reconstruction MSE.
            total_loss = episode_val_mse_recon_loss
            if total_loss < best_loss:
                best_loss = total_loss
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                torch.save(model, args.model_save_file)
def runner_sigmoid(args, path):
    """Train and validate a sigmoid-output LSTM on Moving-MNIST.

    Variant of ``runner``: the data directory comes in as ``path``, the
    optimizer is Adam with default hyperparameters (no scheduler, no
    gradient clipping), optional input standardization is supported via
    ``args.is_standardization``, and the training loss is selected by
    ``args.loss_function`` ('mse' or 'bce').  Expects ``args`` to also
    carry: seed, is_cuda, is_quickrun, batch_size, epoch, model,
    mode ('pred' | 'recon' | 'both'), recon_loss_lambda, is_save,
    save_path, model_save_file.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.is_cuda:
        print('CUDA is set to DO NOT USE')
        device = torch.device("cpu")

    # Moving-MNIST file: axis 0 is time, axis 1 indexes sequences.
    # NOTE(review): assumes shape (20, N, 64, 64) — 10 input + 10 target
    # frames, as the slicing in input_target_maker implies; confirm.
    data_set = np.load(path + 'mnist_test_seq.npy')
    train_set = data_set[:, :9000]
    valid_set = data_set[:, 9000:]
    del data_set  # drop the name; train/valid slices keep what is needed

    if args.is_quickrun:
        # Tiny subset and short schedule for smoke-testing the pipeline.
        train_set = train_set[:, :15]
        valid_set = valid_set[:, :10]
        args.batch_size = 5
        args.epoch = 2

    def input_target_maker(batch, device):
        """Scale frames to [0, 1] (optionally standardize to [-1, 1])
        and build the input, reversed-input reconstruction target, and
        future-frame prediction target tensors on ``device``."""
        batch = batch / 255.
        if args.is_standardization:
            batch = (batch - 0.5) / 0.5
        input_x = batch[:10, :, :, :]
        pred_target = batch[10:, :, :, :]
        # np.flip returns a negative-stride view; make it contiguous
        # before handing it to torch.Tensor.
        rec_target = np.flip(batch[:10, :, :, :], axis=0)
        rec_target = np.ascontiguousarray(rec_target)
        input_x = torch.Tensor(input_x).to(device)
        pred_target = torch.Tensor(pred_target).to(device)
        rec_target = torch.Tensor(rec_target).to(device)
        return input_x, rec_target, pred_target

    def accumulate_losses(out1, out2, rec_t, pred_t,
                          mse_recon, bce_recon, mse_pred, bce_pred):
        """Compute per-mode losses on one batch, append their scalar
        values to the running lists, and return the (mse, bce) loss
        tensors used for backprop.  In 'both' mode the returned tensors
        are the recon_loss_lambda-weighted recon + pred sums."""
        if args.mode == 'pred':
            mse = F.mse_loss(out1, pred_t)
            bce = F.binary_cross_entropy(out1, pred_t)
            mse_recon.append(0)
            bce_recon.append(0)
            mse_pred.append(mse.item())
            bce_pred.append(bce.item())
        elif args.mode == 'recon':
            mse = F.mse_loss(out1, rec_t)
            bce = F.binary_cross_entropy(out1, rec_t)
            mse_recon.append(mse.item())
            bce_recon.append(bce.item())
            mse_pred.append(0)
            bce_pred.append(0)
        else:  # 'both'
            mse_r = F.mse_loss(out1, rec_t)
            bce_r = F.binary_cross_entropy(out1, rec_t)
            mse_p = F.mse_loss(out2, pred_t)
            bce_p = F.binary_cross_entropy(out2, pred_t)
            mse_recon.append(mse_r.item())
            bce_recon.append(bce_r.item())
            mse_pred.append(mse_p.item())
            bce_pred.append(bce_p.item())
            mse = args.recon_loss_lambda * mse_r + mse_p
            bce = args.recon_loss_lambda * bce_r + bce_p
        return mse, bce

    if args.model == 'ED_R_01':
        from models import lstm_copy as m
    elif args.model in ('b', 'c', 'd'):  # all three map to the same class
        from models.ED_lstmcell_v0001 import FC_LSTM as m
    else:
        raise Exception('wrong model')

    model = m(args).to(device)
    optimizer = optim.Adam(model.parameters())

    log = Log(args)
    best_loss = float('inf')

    for e in range(args.epoch):
        batch_mse_recon_loss = []
        batch_bce_recon_loss = []
        batch_mse_pred_loss = []
        batch_bce_pred_loss = []
        idx = np.random.permutation(len(train_set[0]))  # i.i.d. sampling

        for i in range(len(train_set[0]) // args.batch_size):
            model.train()
            # BUG FIX: slice the permutation by whole batches. The old
            # idx[i:i + batch_size] produced overlapping windows drawn
            # from the first few permutation entries only.
            batch_idx = idx[i * args.batch_size:(i + 1) * args.batch_size]
            input_x, rec_target, pred_target = input_target_maker(
                train_set[:, batch_idx], device)

            optimizer.zero_grad()
            rec1, rec2 = model(input_x)
            mse_train_loss, bce_train_loss = accumulate_losses(
                rec1, rec2, rec_target, pred_target,
                batch_mse_recon_loss, batch_bce_recon_loss,
                batch_mse_pred_loss, batch_bce_pred_loss)

            if args.loss_function == 'mse':
                loss = mse_train_loss
            else:  # 'bce'
                loss = bce_train_loss
            loss.backward()
            optimizer.step()

        episode_train_mse_recon_loss = np.mean(batch_mse_recon_loss)
        episode_train_bce_recon_loss = np.mean(batch_bce_recon_loss)
        episode_train_mse_pred_loss = np.mean(batch_mse_pred_loss)
        episode_train_bce_pred_loss = np.mean(batch_bce_pred_loss)

        batch_mse_recon_loss = []
        batch_bce_recon_loss = []
        batch_mse_pred_loss = []
        batch_bce_pred_loss = []

        for i in range(len(valid_set[0]) // args.batch_size):
            with torch.no_grad():
                model.eval()
                # BUG FIX: step through validation in whole batches (the
                # old code sliced i:i + batch_size, re-reading the first
                # sequences every iteration).
                input_x, rec_target, pred_target = input_target_maker(
                    valid_set[:, i * args.batch_size:(i + 1) * args.batch_size],
                    device)
                rec1, rec2 = model(input_x)
                accumulate_losses(
                    rec1, rec2, rec_target, pred_target,
                    batch_mse_recon_loss, batch_bce_recon_loss,
                    batch_mse_pred_loss, batch_bce_pred_loss)

        episode_val_mse_recon_loss = np.mean(batch_mse_recon_loss)
        episode_val_bce_recon_loss = np.mean(batch_bce_recon_loss)
        episode_val_mse_pred_loss = np.mean(batch_mse_pred_loss)
        episode_val_bce_pred_loss = np.mean(batch_bce_pred_loss)

        # T: train, V: validation, M: mse, B: bce, R: recon, P: pred
        log_string = 'Epoch: {}, TMR: {}, TBR: {}, TMP: {}, TBP: {}, VMR: {}, VBR: {}, VMP: {}, VBP: {}'\
            .format(e,
                    episode_train_mse_recon_loss,
                    episode_train_bce_recon_loss,
                    episode_train_mse_pred_loss,
                    episode_train_bce_pred_loss,
                    episode_val_mse_recon_loss,
                    episode_val_bce_recon_loss,
                    episode_val_mse_pred_loss,
                    episode_val_bce_pred_loss)
        log.log(log_string)

        if args.is_save:
            # Model-selection criterion: validation reconstruction MSE.
            total_loss = episode_val_mse_recon_loss
            if total_loss < best_loss:
                best_loss = total_loss
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                torch.save(model, args.model_save_file)