def main():
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # Set up the ResNet-50 backbone and replace its classifier head.
    resnet = torchvision.models.resnet50(pretrained=args.pretrained, progress=True)
    num_features = resnet.fc.in_features
    resnet.fc = nn.Linear(num_features, args.num_classes)
    resnet = resnet.cuda()

    # Set up the optimizer.
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(resnet.parameters(), lr=args.lr, weight_decay=1e-4)
    elif args.optimizer == 'sgd_momentum':
        optimizer = optim.SGD(resnet.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, resnet.parameters()),
            args.g_lr, (args.beta1, args.beta2))
    else:
        optimizer = None
    assert optimizer is not None, f'unknown optimizer: {args.optimizer}'

    criterion = nn.CrossEntropyLoss()

    # Build the datasets: either the full labeled set, or a labeled subset
    # of the training data when args.percentage < 1.0.
    if args.percentage == 1.0:
        train_data, val_data, test_data = get_train_validation_test_data(
            args.train_csv_path, args.train_img_path, args.val_csv_path,
            args.val_img_path, args.test_csv_path, args.test_img_path)
    else:
        train_data = get_label_unlabel_dataset(args.train_csv_path,
                                               args.train_img_path,
                                               args.percentage)
        _, val_data, test_data = get_train_validation_test_data(
            args.train_csv_path, args.train_img_path, args.val_csv_path,
            args.val_img_path, args.test_csv_path, args.test_img_path)

    train_loader = DataLoader(train_data, batch_size=args.train_batch_size,
                              shuffle=True, drop_last=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(val_data, batch_size=args.eval_batch_size,
                            shuffle=True, drop_last=True,
                            num_workers=args.num_workers)
    test_loader = DataLoader(test_data, batch_size=args.eval_batch_size,
                             shuffle=True, drop_last=True,
                             num_workers=args.num_workers)
    print('Training Datasize:', len(train_data))

    start_epoch = 0
    best_acc1 = 0
    best_acc2 = 0
    best_acc3 = 0

    # Set up the writer, resuming from a checkpoint if one is given.
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint_last.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        resnet.load_state_dict(checkpoint['resnet_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # Create a new log dir.
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        logger.info(args)

    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    start = time.time()
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        best_acc1, best_acc2, best_acc3 = train(
            args, resnet, optimizer, criterion, train_loader, val_loader,
            epoch, writer_dict, best_acc1, best_acc2, best_acc3)

        if (epoch and epoch % args.val_freq == 0) or epoch == int(args.max_epoch) - 1:
            val_acc = get_val_acc(val_loader, resnet)
            logger.info(f'Validation Accuracy {val_acc} || @ epoch {epoch}.')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'resnet_state_dict': resnet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'path_helper': args.path_helper
                }, False, False, False,
                args.path_helper['ckpt_path'], filename='checkpoint_last.pth')
    end = time.time()

    final_val_acc = get_val_acc(val_loader, resnet)
    final_test_acc = get_val_acc(test_loader, resnet)
    time_elapsed = end - start
    print('\n Final Validation Accuracy:', final_val_acc.data,
          '\n Final Test Accuracy:', final_test_acc.data,
          '\n Time Elapsed:', time_elapsed, 'seconds.')
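

# NOTE: `get_val_acc` is used above but not defined in this file. The sketch
# below is a minimal stand-in, assuming the loader yields (image, label)
# batches and the model lives on the GPU; the actual helper may differ.
@torch.no_grad()
def _get_val_acc_sketch(loader, model):
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.cuda(), labels.cuda()
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum()
        total += labels.size(0)
    model.train()
    return correct.float() / total  # a tensor, matching the `.data` access above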
def run_trpo(env_id,
             max_iter=1e6,
             gamma=0.99,
             lr=3e-4,
             lam=0.95,
             epsilon=0.2,
             hidden1=64,
             hidden2=64,
             eval_interval=2000,
             steps_per_epoch=4000,
             device='cpu',
             render=False):
    max_iter = int(max_iter)

    env = gym.make(env_id)
    dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)

    agent = TRPOAgent(dimS, dimA, ctrl_range,
                      gamma=gamma, lr=lr, lam=lam, epsilon=epsilon,
                      hidden1=hidden1, hidden2=hidden2,
                      mem_size=steps_per_epoch, device=device, render=render)

    set_log_dir(env_id)
    current_time = time.strftime("%m%d-%H%M%S")
    train_log = open('./train_log/' + env_id + '/TRPO_' + current_time + '.csv',
                     'w', encoding='utf-8', newline='')
    eval_log = open('./eval_log/' + env_id + '/TRPO_' + current_time + '.csv',
                    'w', encoding='utf-8', newline='')
    train_logger = csv.writer(train_log)
    eval_logger = csv.writer(eval_log)

    num_epochs = max_iter // steps_per_epoch
    total_t = 0

    for epoch in range(num_epochs):
        # start agent-env interaction
        state = env.reset()
        step_count = 0
        ep_reward = 0

        for t in range(steps_per_epoch):
            # collect transition samples by executing the policy
            action, log_prob, v = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.Memory.append(state, action, reward, v, log_prob)
            ep_reward += reward
            step_count += 1

            if (step_count == max_ep_len) or (t == steps_per_epoch - 1):
                # truncation by the env wrapper, or by the memory size:
                # bootstrap with the value of the last observed state
                s_last = torch.tensor(next_state, dtype=torch.float).to(device)
                v_last = agent.V(s_last).item()
                agent.Memory.compute_values(v_last)
            elif done:
                # episode done: the agent reached a terminal state
                v_last = 0.0
                agent.Memory.compute_values(v_last)

            state = next_state

            if done:
                train_logger.writerow([total_t, ep_reward])
                state = env.reset()
                step_count = 0
                ep_reward = 0

            if total_t % eval_interval == 0:
                log = agent.eval(env_id, total_t)
                eval_logger.writerow(log)

            total_t += 1

        # train the agent at the end of each epoch
        agent.train(num_iter=1)

    train_log.close()
    eval_log.close()
    return
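

# NOTE: `agent.Memory.compute_values(v_last)` is expected to perform a
# GAE(lambda) pass over the stored trajectory. The function below is a
# minimal illustrative sketch of that computation; the names are
# hypothetical, not the repo's actual buffer API.
import numpy as np

def _compute_gae_sketch(rewards, values, v_last, gamma=0.99, lam=0.95):
    values = np.append(values, v_last)       # bootstrap with V(s_last)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(deltas)
    gae = 0.0
    for t in reversed(range(len(rewards))):  # backward recursion over the rollout
        gae = deltas[t] + gamma * lam * gae
        advantages[t] = gae
    returns = advantages + values[:-1]       # regression targets for V
    return advantages, returns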
def main(index, args):
    device = xm.xla_device()

    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError(
                    '{} unknown init type'.format(args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid
    para_loader = pl.ParallelLoader(train_loader, [device])

    fid_stat = str(pathlib.Path(__file__).parent.absolute()) + \
        '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    is_best = True
    args.num_epochs = np.ceil(args.num_iter / len(train_loader))

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr, 0,
                                       args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)

    # initial state
    start_epoch = 0
    best_fid = 1e4

    # set writer, resuming from a checkpoint if one is given
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create a new log dir
        assert args.exp_name
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        logger.info(args)

    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler,
                         ae_recon_scheduler, ae_reg_scheduler)
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer,
              para_loader, epoch, writer_dict, lr_schedulers)

        if (epoch and epoch % args.val_freq == 0) or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
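

# NOTE: `LinearLrDecay` is used above but not defined here. A minimal sketch,
# assuming the learning rate stays at `start_lr` until `decay_start` steps and
# then decays linearly to `end_lr` at `decay_end`; the repo's version may
# differ in details.
class _LinearLrDecaySketch:
    def __init__(self, optimizer, start_lr, end_lr, decay_start, decay_end):
        assert start_lr > end_lr
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.decay_start = decay_start
        self.delta = (start_lr - end_lr) / (decay_end - decay_start)

    def step(self, current_step):
        if current_step <= self.decay_start:
            lr = self.start_lr
        else:
            lr = max(self.end_lr,
                     self.start_lr - self.delta * (current_step - self.decay_start))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr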
def run_ddpg(env_id,
             gamma=0.99,
             actor_lr=1e-4,
             critic_lr=1e-3,
             tau=1e-3,
             sigma=0.1,
             hidden_size1=64,
             hidden_size2=64,
             max_iter=1e5,
             eval_interval=1000,
             start_train=10000,
             train_interval=50,
             buffer_size=1e5,
             fill_buffer=20000,
             batch_size=64,
             num_checkpoints=5,
             render=False):
    args = locals()

    max_iter = int(max_iter)
    buffer_size = int(buffer_size)
    checkpoint_interval = max_iter // (num_checkpoints - 1)

    env = gym.make(env_id)
    dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)

    agent = DDPGAgent(dimS, dimA,
                      gamma=gamma,
                      actor_lr=actor_lr,
                      critic_lr=critic_lr,
                      tau=tau,
                      sigma=sigma,
                      hidden_size1=hidden_size1,
                      hidden_size2=hidden_size2,
                      buffer_size=buffer_size,
                      batch_size=batch_size,
                      render=render)

    set_log_dir(env_id)
    current_time = time.strftime("%m%d-%H%M%S")
    train_log = open('./train_log/' + env_id + '/DDPG_' + current_time + '.csv',
                     'w', encoding='utf-8', newline='')
    eval_log = open('./eval_log/' + env_id + '/DDPG_' + current_time + '.csv',
                    'w', encoding='utf-8', newline='')
    train_logger = csv.writer(train_log)
    eval_logger = csv.writer(eval_log)

    # save the parameter configuration next to the eval log
    with open('./eval_log/' + env_id + '/DDPG_' + current_time + '.txt', 'w') as f:
        for key, val in args.items():
            print(key, '=', val, file=f)

    state = env.reset()
    step_count = 0
    ep_reward = 0

    # main loop
    for t in range(max_iter + 1):
        if t < fill_buffer:
            # warm up the buffer with uniformly random actions
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)

        next_state, reward, done, _ = env.step(action)
        step_count += 1

        if step_count == max_ep_len:
            # timeout: do not treat truncation as a true terminal state
            done = False

        agent.buffer.append(state, action, reward, next_state, done)
        state = next_state
        ep_reward += reward

        if done or (step_count == max_ep_len):
            train_logger.writerow([t, ep_reward])
            state = env.reset()
            step_count = 0
            ep_reward = 0

        if (t >= start_train) and (t % train_interval == 0):
            for _ in range(train_interval):
                agent.train()

        if t % eval_interval == 0:
            eval_score = eval_agent(agent, env_id, render=False)
            log = [t, eval_score]
            print('step {} : {:.4f}'.format(t, eval_score))
            eval_logger.writerow(log)

        if t % (10 * eval_interval) == 0 and render:
            render_agent(agent, env_id)

        if t % checkpoint_interval == 0:
            agent.save_model('./checkpoints/' + env_id + '/DDPG(iter={})'.format(t))

    train_log.close()
    eval_log.close()
    return
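

# NOTE: `eval_agent` is called above but defined elsewhere. A minimal sketch,
# assuming the agent exposes `get_action(obs)` as in the loop above; the
# repo's helper may average over a different number of episodes or pass an
# explicit evaluation flag.
def _eval_agent_sketch(agent, env_id, render=False, num_episodes=5):
    env = gym.make(env_id)
    total = 0.0
    for _ in range(num_episodes):
        obs, done, ep_ret = env.reset(), False, 0.0
        while not done:
            if render:
                env.render()
            obs, reward, done, _ = env.step(agent.get_action(obs))
            ep_ret += reward
        total += ep_ret
    env.close()
    return total / num_episodes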
def run_dqn(env_id,
            gamma=0.99999,
            lr=1e-4,
            polyak=1e-3,
            hidden1=256,
            hidden2=256,
            num_ep=2e3,
            buffer_size=1e6,
            fill_buffer=20000,
            batch_size=128,
            train_interval=50,
            start_train=10000,
            eval_interval=20,
            device='cuda',
            render=False):
    arg_dict = locals()

    num_ep = int(num_ep)
    buffer_size = int(buffer_size)

    env = gym.make(env_id)
    dimS = env.observation_space.shape[0]  # dimension of state space
    nA = env.action_space.n                # number of actions

    # (physical) length of the time horizon of each truncated episode;
    # each episode runs for t in [0, T) -- set for RL in the semi-MDP setting
    T = 3000000

    agent = SemiDQNAgent(dimS, nA, env.action_map_no_wt, gamma,
                         hidden1, hidden2, lr, polyak,
                         buffer_size, batch_size,
                         device=device, render=render)

    # log setting
    set_log_dir(env_id)
    current_time = time.strftime("%m%d-%H%M%S")
    log_file = open('./log/' + env_id + '/semiDQN_' + current_time + '.csv',
                    'w', encoding='utf-8', newline='')
    logger = csv.writer(log_file)
    with open('./log/' + env_id + '/semiDQN_' + current_time + '.txt', 'w') as f:
        for key, val in arg_dict.items():
            print(key, '=', val, file=f)

    # start environment roll-out
    max_epsilon = 1.
    min_epsilon = 0.02
    # linearly scheduled epsilon
    exploration_schedule = LinearSchedule(begin_t=0,
                                          end_t=num_ep,
                                          begin_value=max_epsilon,
                                          end_value=min_epsilon)

    carried = None
    for i in range(num_ep):
        s = env.reset()
        t = 0.  # physical elapsed time of the present episode
        ep_reward = 0.
        epsilon = exploration_schedule(i)

        while t < T:
            a = agent.get_action(s, epsilon)
            s_next, r, d, info = env.step(a)
            ep_reward += gamma**t * r
            dt = info['dt']
            carried = info['carried']
            t = info['elapsed time']  # advance physical time so the loop terminates
            # TODO : expand prioritized replay buffer
            agent.buffer.append(s, a, r, s_next, False, dt)
            agent.train()
            s = s_next  # please don't forget this...please...

        log_time = datetime.datetime.now(tz=None).strftime("%Y-%m-%d %H:%M:%S")
        print('{} (episode {} / epsilon = {:.2f}) reward = {:.4f} | carried = {}'
              .format(log_time, i, epsilon, ep_reward, carried))
        logger.writerow([i, ep_reward, carried])

    log_file.close()
    """
    if i % eval_interval == 0:
        eval_score = eval_agent(agent, env_id, render=False)
        log = [i, eval_score]
        print('step {} : {:.4f}'.format(i, eval_score))
        eval_logger.writerow(log)
    """
    return
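

# NOTE: the buffer above stores the sojourn time `dt` alongside each
# transition. For reference, a minimal sketch of the semi-MDP Bellman target
# that this enables: with variable sojourn times the discount becomes
# gamma**dt rather than a fixed gamma. Illustrative only; SemiDQNAgent.train()
# may organize this differently.
import torch

def _semi_mdp_targets_sketch(rewards, next_q_max, dts, dones, gamma):
    # rewards, next_q_max, dts, dones: 1-D tensors sampled from the buffer
    return rewards + (1.0 - dones) * (gamma ** dts) * next_q_max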
def run_sac(env_id,
            max_iter=1e6,
            eval_interval=2000,
            start_train=10000,
            train_interval=50,
            buffer_size=1e6,
            fill_buffer=20000,
            truncate=1000,
            gamma=0.99,
            pi_lr=3e-4,
            q_lr=3e-4,
            polyak=5e-3,
            alpha=0.2,
            hidden1=256,
            hidden2=256,
            batch_size=128,
            device='cpu',
            render=False):
    params = locals()

    max_iter = int(max_iter)
    buffer_size = int(buffer_size)

    env = gym.make(env_id)
    dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)
    if truncate is not None:
        max_ep_len = truncate

    agent = SACAgent(dimS, dimA, ctrl_range,
                     gamma=gamma,
                     pi_lr=pi_lr,
                     q_lr=q_lr,
                     polyak=polyak,
                     alpha=alpha,
                     hidden1=hidden1,
                     hidden2=hidden2,
                     buffer_size=buffer_size,
                     batch_size=batch_size,
                     device=device,
                     render=render)

    set_log_dir(env_id)
    num_checkpoints = 5
    checkpoint_interval = max_iter // (num_checkpoints - 1)

    current_time = time.strftime("%m%d-%H%M%S")
    train_log = open('./train_log/' + env_id + '/SAC_' + current_time + '.csv',
                     'w', encoding='utf-8', newline='')
    path = './eval_log/' + env_id + '/SAC_' + current_time
    eval_log = open(path + '.csv', 'w', encoding='utf-8', newline='')
    train_logger = csv.writer(train_log)
    eval_logger = csv.writer(eval_log)

    # save the parameter configuration
    with open(path + '.txt', 'w') as f:
        for key, val in params.items():
            print(key, '=', val, file=f)

    obs = env.reset()
    step_count = 0
    ep_reward = 0

    # main loop
    start = time.time()
    for t in range(max_iter + 1):
        if t < fill_buffer:
            # warm up the buffer with uniformly random actions
            action = env.action_space.sample()
        else:
            action = agent.act(obs)

        next_obs, reward, done, _ = env.step(action)
        step_count += 1

        if step_count == max_ep_len:
            # timeout: do not treat truncation as a true terminal state
            done = False

        agent.buffer.append(obs, action, next_obs, reward, done)
        obs = next_obs
        ep_reward += reward

        if done or (step_count == max_ep_len):
            train_logger.writerow([t, ep_reward])
            obs = env.reset()
            step_count = 0
            ep_reward = 0

        if (t >= start_train) and (t % train_interval == 0):
            for _ in range(train_interval):
                agent.train()

        if t % eval_interval == 0:
            eval_score = eval_agent(agent, env_id, render=False)
            log = [t, eval_score]
            print('step {} : {:.4f}'.format(t, eval_score))
            eval_logger.writerow(log)

        if t % (10 * eval_interval) == 0 and render:
            render_agent(agent, env_id)

        if t % checkpoint_interval == 0:
            agent.save_model('./checkpoints/' + env_id + '/sac_{}th_iter_'.format(t))

    train_log.close()
    eval_log.close()
    return
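

# NOTE: the `polyak` coefficient above controls the soft target-network
# update. A minimal sketch of that update, illustrative only; SACAgent
# presumably applies it inside its own train() step.
import torch

def _soft_update_sketch(target_net, source_net, polyak):
    with torch.no_grad():
        for p_targ, p in zip(target_net.parameters(), source_net.parameters()):
            # theta_targ <- (1 - polyak) * theta_targ + polyak * theta
            p_targ.mul_(1.0 - polyak).add_(polyak * p)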
def run_sac(env_id,
            num_ep=2e3,
            T=300,
            eval_interval=2000,
            start_train=10000,
            train_interval=50,
            buffer_size=1e6,
            fill_buffer=20000,
            gamma=0.99,
            pi_lr=3e-4,
            q_lr=3e-4,
            alpha_lr=3e-4,
            polyak=5e-3,
            adjust_entropy=False,
            alpha=0.2,
            target_entropy=-6.0,
            hidden1=256,
            hidden2=256,
            batch_size=64,
            pth=None,
            device='cpu',
            render=False):
    arg_dict = locals()

    num_ep = int(num_ep)
    buffer_size = int(buffer_size)

    env = gym.make(env_id)
    test_env = gym.make(env_id)
    dimS = env.observation_space.shape[0]  # dimension of state space
    nA = env.action_space.n                # number of actions

    # T is the (physical) length of the time horizon of each truncated episode;
    # each episode runs for t in [0, T) -- set for RL in the semi-MDP setting
    agent = SACAgent(dimS, nA, env.action_map, gamma,
                     pi_lr=pi_lr,
                     q_lr=q_lr,
                     alpha_lr=alpha_lr,
                     polyak=polyak,
                     adjust_entropy=adjust_entropy,
                     target_entropy=target_entropy,
                     alpha=alpha,
                     hidden1=hidden1,
                     hidden2=hidden2,
                     buffer_size=buffer_size,
                     batch_size=batch_size,
                     device=device,
                     render=render)

    # log setting
    set_log_dir(env_id)
    current_time = time.strftime("%m%d-%H%M%S")
    if pth is None:
        # default location of the directory for training logs
        pth = './log/' + env_id + '/'
    os.makedirs(pth, exist_ok=True)

    file_name = pth + 'prioritized_' + current_time
    log_file = open(file_name + '.csv', 'w', encoding='utf-8', newline='')
    eval_log_file = open(file_name + '_eval.csv', 'w', encoding='utf-8',
                         newline='')
    logger = csv.writer(log_file)
    eval_logger = csv.writer(eval_log_file)

    # save parameter configuration
    with open('./log/' + env_id + '/SAC_' + current_time + '.txt', 'w') as f:
        for key, val in arg_dict.items():
            print(key, '=', val, file=f)

    OPERATION_HOUR = T * num_ep
    # 200 evaluations in total
    EVALUATION_INTERVAL = OPERATION_HOUR / 200
    evaluation_count = 0

    # start environment roll-out
    global_t = 0.
    counter = 0
    for i in range(num_ep):
        if global_t >= OPERATION_HOUR:
            break

        # initialize an episode
        s = env.reset()
        t = 0.  # physical elapsed time of the present episode
        info = None
        ep_reward = 0.

        while t < T:
            # evaluation is done periodically
            if evaluation_count * EVALUATION_INTERVAL <= global_t:
                result = agent.eval(test_env, T=14400, eval_num=3)
                log = [i] + result
                eval_logger.writerow(log)
                evaluation_count += 1

            if counter < fill_buffer:
                a = random.choice(env.action_map(s))
            else:
                a = agent.get_action(s)

            s_next, r, d, info = env.step(a)
            ep_reward += gamma**t * r
            dt = info['dt']
            t = info['elapsed_time']
            global_t += dt
            counter += 1

            agent.buffer.append(s, a, r, s_next, False, dt)

            if counter >= start_train and counter % train_interval == 0:
                # training stage: a single gradient step per observed transition
                for _ in range(train_interval):
                    agent.train()

            s = s_next

        # save training statistics
        log_time = datetime.now(tz=None).strftime("%Y-%m-%d %H:%M:%S")
        op_log = env.operation_log
        print('+' + '=' * 78 + '+')
        print('+' + '-' * 31 + 'TRAIN-STATISTICS' + '-' * 31 + '+')
        print('{} (episode {}) reward = {:.4f}'.format(log_time, i, ep_reward))
        print('+' + '-' * 32 + 'FAB-STATISTICS' + '-' * 32 + '+')
        print('carried = {}/{}\n'.format(op_log['carried'], sum(op_log['total'])) +
              'remain_quantity : {}\n'.format(op_log['waiting_quantity']) +
              'visit_count : {}\n'.format(op_log['visit_count']) +
              'load_two : {}\n'.format(op_log['load_two']) +
              'unload_two : {}\n'.format(op_log['unload_two']) +
              'load_sequential : {}\n'.format(op_log['load_sequential']) +
              'total : ', op_log['total'])
        print('+' + '=' * 78 + '+')
        print('\n', end='')

        logger.writerow([i, ep_reward, op_log['carried']] +
                        op_log['waiting_quantity'] +
                        list(op_log['visit_count']) +
                        [op_log['load_two'], op_log['unload_two'],
                         op_log['load_sequential']] +
                        list(op_log['total']))

    log_file.close()
    eval_log_file.close()
    return
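

# NOTE: `adjust_entropy`, `target_entropy`, and `alpha_lr` above suggest SAC's
# automatic temperature tuning. A minimal sketch of that update under the
# standard log-alpha parameterization (keeping alpha positive); the names are
# hypothetical, and the agent's internal implementation may differ.
import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def _update_alpha_sketch(log_probs, target_entropy):
    # J(alpha) = E[-alpha * (log pi(a|s) + target_entropy)]
    alpha_loss = -(log_alpha.exp() * (log_probs + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().item()  # current temperature alpha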