def main():
    args = parser.parse_args()

    # Set sequence length (defaults to the number of target objects) and visualization directory
    args.seq_len = args.target_num if args.seq_len is None else args.seq_len
    args.visualization_dir = os.path.join('exp', args.exp, 'visualization')
    utils.mkdir(args.visualization_dir)

    # Set exp directory and tensorboard writer
    writer_dir = os.path.join('exp', args.exp)
    utils.mkdir(writer_dir)
    writer = SummaryWriter(writer_dir)

    # Save arguments
    str_list = []
    for key in vars(args):
        print('[{0}] = {1}'.format(key, getattr(args, key)))
        str_list.append('--{0}={1} \\'.format(key, getattr(args, key)))
    with open(os.path.join('exp', args.exp, 'args.txt'), 'w+') as f:
        f.write('\n'.join(str_list))

    # Set directories for the replay buffer and model snapshots
    args.replay_buffer_dir = os.path.join('exp', args.exp, 'replay_buffer')
    args.model_dir = os.path.join('exp', args.exp, 'models')
    utils.mkdir(args.model_dir)

    # Set random seeds for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set device
    device = torch.device('cpu') if args.gpu == '-1' else torch.device(
        f'cuda:{args.gpu}')

    # Set replay buffer
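    # (dumped to / loaded from replay_buffer_dir and capped at replay_buffer_size samples)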
    replay_buffer = ReplayBuffer(args.replay_buffer_dir,
                                 args.replay_buffer_size)
    if args.load_replay_buffer is not None:
        print(f'==> Loading replay buffer from {args.load_replay_buffer}')
        replay_buffer.load(
            os.path.join('exp', args.load_replay_buffer, 'replay_buffer'))
        print(
            f'==> Loaded replay buffer from {args.load_replay_buffer} [size = {replay_buffer.length}]'
        )

    # Set model and optimizer
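    # Model variants (presumably): 'adagrasp' conditions on both the initial and
    # final (closed) gripper states, 'adagrasp_init_only' on the initial state
    # only, and 'scene_only' on the scene observation without per-gripper input.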
    if args.model_type == 'adagrasp':
        model = GraspingModel(num_rotations=args.num_rotations,
                              gripper_final_state=True)
    elif args.model_type == 'adagrasp_init_only':
        model = GraspingModel(num_rotations=args.num_rotations,
                              gripper_final_state=False)
    elif args.model_type == 'scene_only':
        model = GraspingModelSceneOnly(num_rotations=args.num_rotations,
                                       gripper_final_state=True)
    else:
        raise NotImplementedError(f'Unsupported model_type: {args.model_type}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(0.9, 0.95))
    model = model.to(device)
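    # Note: constructing the optimizer before model.to(device) is safe here,
    # since .to() moves the existing Parameter tensors in place and Adam
    # initializes its per-parameter state lazily on the first step.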

    # Report CUDA memory allocated after moving the model to the device
    if args.gpu != '-1':
        bytes_allocated = torch.cuda.memory_allocated(device)
        print("Model size: {:.3f} MB".format(bytes_allocated / (1024**2)))

    # Load checkpoint (either a direct path to a .pth file or an experiment name under exp/)
    if args.load_checkpoint is not None:
        print(f'==> Loading checkpoint from {args.load_checkpoint}')
        if args.load_checkpoint.endswith('.pth'):
            checkpoint = torch.load(args.load_checkpoint, map_location=device)
        else:
            checkpoint = torch.load(os.path.join('exp', args.load_checkpoint,
                                                 'models', 'latest.pth'),
                                    map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
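        # Resume the epoch counter only when a replay buffer is also being
        # reloaded; otherwise restart data collection from epoch 0.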
        start_epoch = checkpoint[
            'epoch'] if args.load_replay_buffer is not None else 0
        print(f'==> Loaded checkpoint from {args.load_checkpoint}')
    else:
        start_epoch = 0

    # Launch one worker process per environment (GUI mode forces a single environment)
    args.num_envs = 1 if args.gui else args.num_envs
    processes, conns = [], []
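    # Use the 'spawn' start method so each worker gets a fresh interpreter;
    # fork-started workers are unsafe once a CUDA context has been initialized.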
    ctx = mp.get_context('spawn')
    for rank in range(args.num_envs):
        conn_main, conn_env = ctx.Pipe()
        reset_args = {
            'num_open_scale': args.num_open_scale,
            'max_open_scale': args.max_open_scale,
            'min_open_scale': args.min_open_scale,
            'gripper_final_state': args.model_type == 'adagrasp',
            'target_num': args.target_num,
            'obstacle_num': args.obstacle_num
        }
        p = ctx.Process(target=env_process,
                        args=(rank, start_epoch + args.seed + rank, conn_env,
                              args.gui, args.num_cam, args.seq_len,
                              reset_args))
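        # Daemonize workers so they are killed automatically if the main process exits.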
        p.daemon = True
        p.start()
        processes.append(p)
        conns.append(conn_main)

    # Initialize exit signal handler (for graceful exits)
    def save_and_exit(signum, frame):
        print('Cleaning up: terminating workers and saving state...')
        for p in processes:
            p.terminate()
        replay_buffer.dump()
        writer.close()
        print('Finished. Now exiting gracefully.')
        sys.exit(0)

    signal.signal(signal.SIGINT, save_and_exit)

    for epoch in range(start_epoch, args.epoch):
        print(f'---------- epoch-{epoch + 1} ----------')
        timestamp = time.time()

        assert args.min_epsilon <= args.max_epsilon
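        # Linearly anneal epsilon from max_epsilon down to min_epsilon over the
        # first exploration_epoch epochs, then hold it at min_epsilon.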
        m1, m2 = args.min_epsilon, args.max_epsilon
        epsilon = max(m1, m2 - (m2 - m1) * epoch / args.exploration_epoch)
        # Data collection: one epsilon-greedy step in each environment
        data = collect_data(
            conns,
            model,
            device,
            n_steps=1,
            epsilon=epsilon,
            gripper_final_state=(args.model_type == 'adagrasp'))

        for d in data.values():
            replay_buffer.save_data(d)

        average_reward = np.mean([d['reward'] for d in data.values()])
        average_score = np.mean([d['score'] for d in data.values()])
        print(
            f'Mean reward = {average_reward:.3f}, Mean score = {average_score:.3f}'
        )
        writer.add_scalar('Data Collection/Reward', average_reward, epoch + 1)
        writer.add_scalar('Data Collection/Score', average_score, epoch + 1)

        time_data_collection = time.time() - timestamp

        # Replay buffer statistics
        reward_data = np.array(replay_buffer.scalar_data['reward'])
        num_positive = int(np.sum(reward_data == 1))
        num_negative = int(np.sum(reward_data == 0))
        print(f'Replay buffer size = {len(reward_data)} '
              f'(positive = {num_positive}, negative = {num_negative})')

        # Policy training
        model.train()
        torch.set_grad_enabled(True)
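        # Each call to train() returns the iteration loss and score-prediction
        # statistics, where index 1 corresponds to positive samples and index 0
        # to negative samples.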
        sum_loss = 0
        score_stats = {'positive': list(), 'negative': list()}
        for _ in range(args.iter_per_epoch):
            iter_loss, iter_score_stats = train(
                model,
                device,
                replay_buffer,
                optimizer,
                args.batch_size,
                gripper_final_state=(args.model_type == 'adagrasp'))
            sum_loss += iter_loss
            score_stats['positive'].append(iter_score_stats[1])
            score_stats['negative'].append(iter_score_stats[0])
        average_loss = sum_loss / args.iter_per_epoch
        positive_score_prediction = np.mean(score_stats['positive'])
        negative_score_prediction = np.mean(score_stats['negative'])
        print(
            f'Training loss = {average_loss:.5f}, positive_mean = {positive_score_prediction:.3f}, negative_mean = {negative_score_prediction:.3f}'
        )
        writer.add_scalar('Policy Training/Loss', average_loss, epoch + 1)
        writer.add_scalar('Policy Training/Positive Score Prediction',
                          positive_score_prediction, epoch + 1)
        writer.add_scalar('Policy Training/Negative Score Prediction',
                          negative_score_prediction, epoch + 1)

        # Periodically visualize, save a model/optimizer snapshot, and dump the replay buffer
        if (epoch + 1) % args.snapshot_gap == 0:
            model.eval()
            torch.set_grad_enabled(False)

            # Visualization: run a greedy rollout (epsilon = 0) for a full sequence in each environment
            for conn in conns:
                conn.send("reset")
            data = collect_data(
                conns,
                model,
                device,
                n_steps=args.seq_len,
                epsilon=0,
                gripper_final_state=(args.model_type == 'adagrasp'))

            vis_path = os.path.join(args.visualization_dir,
                                    'epoch_%06d' % (epoch + 1))
            utils.visualization(data, args.num_envs, args.seq_len,
                                args.num_open_scale, args.num_rotations,
                                args.num_vis, vis_path)

            save_state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            torch.save(save_state, os.path.join(args.model_dir, 'latest.pth'))
            shutil.copyfile(
                os.path.join(args.model_dir, 'latest.pth'),
                os.path.join(args.model_dir, 'epoch_%06d.pth' % (epoch + 1)))

            # Save replay buffer
            replay_buffer.dump()

        # Print elapsed time for an epoch
        time_all = time.time() - timestamp
        time_training = time_all - time_data_collection
        print(
            f'Elapsed time = {time_all:.2f}: (collect) {time_data_collection:.2f} + (train) {time_training:.2f}'
        )

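    # Normal termination: reuse the cleanup handler to stop workers, dump the
    # replay buffer, and close the TensorBoard writer before exiting.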
    save_and_exit(None, None)