Пример #1
0
    netD.cuda()
    netG.cuda()
    input = input.cuda()
    one, mone = one.cuda(), mone.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

# Build one optimizer per network: Adam when requested on the command
# line, otherwise the WGAN-default RMSprop.
def _build_optimizer(params, lr):
    if opt.adam:
        return optim.Adam(params, lr=lr, betas=(opt.beta1, 0.999))
    return optim.RMSprop(params, lr=lr)


optimizerD = _build_optimizer(netD.parameters(), opt.lrD)
optimizerG = _build_optimizer(netG.parameters(), opt.lrG)

logger = Logger(model_name='LRPGAN', data_name='lsun')

gen_iterations = 0
for epoch in range(opt.niter):
    data_iter = iter(dataloader)
    i = 0
    while i < len(dataloader):
        ############################
        # (1) Update D network
        ###########################
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
Пример #2
0
Файл: paac.py Проект: gikr/t_lab
    def __init__(self, network_creator, batch_env, args):
        """Set up the PAAC agent: network, optimizer, LR schedule and
        bookkeeping state, restoring everything from the latest
        checkpoint when one exists.

        Args:
            network_creator: zero-argument callable returning the network.
            batch_env: batched environment; must expose ``num_actions``.
            args: parsed command-line arguments (converted to a dict here).
        """
        logging.debug('PAAC init is started')
        self.args = copy.copy(vars(args))
        self.checkpoint_dir = join_path(self.args['debugging_folder'],
                                        self.CHECKPOINT_SUBDIR)
        ensure_dir(self.checkpoint_dir)

        # Resume from the newest checkpoint, if any.
        checkpoint = self._load_latest_checkpoint(self.checkpoint_dir)
        self.last_saving_step = checkpoint['last_step'] if checkpoint else 0

        self.final_rewards = []
        self.global_step = self.last_saving_step
        self.network = network_creator()
        self.batch_env = batch_env
        self.optimizer = optim.RMSprop(
            self.network.parameters(),
            lr=self.args['initial_lr'],
            eps=self.args['e'],
        )  # RMSprop defaults: momentum=0., centered=False, weight_decay=0

        self.length_previous = [10, 15]

        if checkpoint:
            logging.info('Restoring agent variables from previous run')
            self.network.load_state_dict(checkpoint['network_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.lr_scheduler = LinearAnnealingLR(self.optimizer,
                                              self.args['lr_annealing_steps'])
        # PyTorch documentation says: in most cases it's better to use the
        # CUDA_VISIBLE_DEVICES environment variable, so to pick a particular
        # GPU one should use CUDA_VISIBLE_DEVICES.
        self.use_cuda = self.args['device'] == 'gpu'
        self.use_rnn = hasattr(
            self.network, 'get_initial_state'
        )  # get_initial_state should return the state of the RNN layers
        self._tensors = torch.cuda if self.use_cuda else torch

        self.action_codes = np.eye(
            batch_env.num_actions)  # envs receive actions in one-hot encoding!
        self.gamma = self.args['gamma']  # future rewards discount factor
        self.entropy_coef = self.args['entropy_regularisation_strength']
        self.loss_scaling = self.args['loss_scaling']  # 5.
        self.critic_coef = self.args['critic_coef']  # 0.25
        self.eval_func = None

        # new variables
        self.success_time = 0
        self.rewards_deque = deque(maxlen=64)

        # Gradient-clipping strategy, selected by name.
        if self.args['clip_norm_type'] == 'global':
            self.clip_gradients = nn.utils.clip_grad_norm_
        elif self.args['clip_norm_type'] == 'local':
            self.clip_gradients = utils.clip_local_grad_norm
        elif self.args['clip_norm_type'] == 'ignore':
            # No clipping: just report the global gradient norm.
            self.clip_gradients = lambda params, _: utils.global_grad_norm(
                params)
        else:
            # Fixed typo in the user-facing error message ("recoginized").
            raise ValueError('Norm type({}) is not recognized'.format(
                self.args['clip_norm_type']))
        logging.debug('Paac init is done')
Пример #3
0
def train(dataset, graphs, config):
    """Train the DQN policy on *dataset*, resuming from a saved model when
    one exists, and periodically checkpointing the learning bundle.

    Args:
        dataset: training dataset; must provide ``attr_encoder``.
        graphs: mapping from ``data_point.graph_id`` to graph object.
        config: episode configuration passed through to ``episode``.
    """
    # Resume from a saved model when a non-empty checkpoint file exists;
    # otherwise build a fresh policy/decoder/memory/optimizer bundle.
    if os.path.exists(
            cmd_args.model_path) and os.path.getsize(cmd_args.model_path) > 0:
        model = torch.load(cmd_args.model_path)
        DC.steps_done = model.steps_done
    else:
        decoder_name = "GlobalDecoder"
        policy = DQPolicy(dataset, decoder_name)
        decoder = policy.decoder
        current_it = 0
        memory = ReplayMemory(10000)
        optimizer = optim.RMSprop(policy.parameters(), lr=cmd_args.lr)
        DC.steps_done = 0
        model = Learning_Model(decoder, policy, memory, optimizer, current_it,
                               DC.steps_done)

    # Target network: a copy of the policy put in eval mode.
    decoder_name = str(type(model.policy.decoder).__name__)
    target = DQPolicy(dataset, decoder_name)
    target.load_state_dict(model.policy.state_dict())
    target.eval()

    data_loader = DataLoader(dataset)

    for it in range(cmd_args.episode_iter):
        logging.info(f"training iteration: {it}")
        success_ct = 0
        total_ct = 0

        # Skip iterations already completed by a resumed model.
        if model.current_it > it:
            continue

        for data_point, ct in zip(data_loader, tqdm(range(len(data_loader)))):

            logging.info(f"task ct: {ct}")

            graph = graphs[data_point.graph_id]
            suc = episode(model.policy, target, data_point, graph, config,
                          dataset.attr_encoder, model.memory, total_ct,
                          model.optimizer)

            total_ct += 1
            if suc:
                success_ct += 1

        logging.info(f"success count: {success_ct}")

        # Periodically checkpoint the whole learning bundle to its own file.
        if it % cmd_args.save_num == 0:
            model.steps_done = DC.steps_done
            model.current_it = it
            model_name = f"model_{it}.pkl"
            model_path = os.path.join(cmd_args.model_save_dir, model_name)
            torch.save(model, model_path)

    print('Complete')
Пример #4
0
# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this module-level handle is never closed here — confirm
# it is closed (or use a context manager) elsewhere in the script.
logs = open(log_path, 'w')

transform = T.Compose([T.ToPILImage(), T.ToTensor()])

# Target network starts as a copy of the policy network; eval() puts it
# in inference mode (affects dropout/batchnorm layers, if any).
policy_net = dqn.DQN(n_angle, n_actions, hidden_layer1_size,
                     hidden_layer2_size).to(device)
target_net = dqn.DQN(n_angle, n_actions, hidden_layer1_size,
                     hidden_layer2_size).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# RMSprop with library defaults (lr=0.01 in PyTorch).
optimizer = optim.RMSprop(policy_net.parameters())
memory = utils.ReplayMemory(10000)

# Global step counter read by the epsilon-greedy schedule below.
steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
Пример #5
0
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 20)
        self.fc4 = nn.Linear(20, nb_movies)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Encode then decode *x*: sigmoid after fc1-fc3, linear fc4 output."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.activation(layer(hidden))
        return self.fc4(hidden)


# Stacked autoencoder, MSE reconstruction loss, RMSprop with L2 penalty.
sae = SAE()
criterion = nn.MSELoss()
# NOTE(review): weight_decay=0.5 is an unusually strong L2 penalty — confirm.
optimizer = optim.RMSprop(sae.parameters(), lr=0.01, weight_decay=0.5)

# Training the Auto Encoders
nb_epoch = 200
for epoch in range(1, nb_epoch + 1):
    train_loss = 0
    s = 0.
    for id_user in range(nb_users):
        input = Variable(training_set[id_user]).unsqueeze(0)
        target = input.clone()
        if torch.sum(target.data > 0) > 0:
            output = sae(input)
            target.require_grad = False
            output[target == 0] = 0
            loss = criterion(output, target)
            mean_corrector = nb_movies / float(
Пример #6
0
    def train(self):
        """Run the NNMF training loop: per batch, alternately optimize the
        sparse head (mlp_s, from epoch 15), the reconstruction MLP, and
        the latent factors, then persist x_hat / s and selected weights
        at the end of each epoch.
        """
        # initialize epochs
        ep = 0

        # create data loader
        print('getting ind matrix')
        ind_mat = self.matrix_to_pixel_frame_target(self.matrix2d)
        print('loading dataset')
        trainloader = self.load_dataset(ind_mat, batch_size=self.batchsize)

        # initialize network parameters and send to gpu
        if self.nmf_init:
            # Warm-start the factors with a classical NMF decomposition.
            model = NMF(n_components=2,
                        init='random',
                        random_state=0,
                        max_iter=200,
                        tol=0.0001)
            W = model.fit_transform(self.matrix2d)
            H = model.components_
            self.nnmf.init_params(NMF=(W, H))
        else:
            self.nnmf.init_params()

        self.nnmf.to(self.device)

        # optimizers for mlp and latent features

        # d == 0 presumably means the U/V embedding tables are absent, so
        # only the prime factors are optimized — TODO confirm.
        if self.d != 0:
            latent_params = list(self.nnmf.U.parameters()) + list(self.nnmf.Uprime_1.parameters()) + \
                            list(self.nnmf.Uprime_2.parameters()) + list(self.nnmf.V.parameters()) + \
                            list(self.nnmf.Vprime_1.parameters()) + list(self.nnmf.Vprime_2.parameters())
        else:
            latent_params = list(self.nnmf.Uprime_1.parameters()) + list(self.nnmf.Uprime_2.parameters()) + \
                            list(self.nnmf.Vprime_1.parameters()) + list(self.nnmf.Vprime_2.parameters())  # + \

        latent_opt = optim.RMSprop(latent_params, lr=self.lr)
        # latent_opt = optim.Adam(latent_params, lr=self.lr)

        mlp_opt = optim.RMSprop(self.nnmf.mlp.parameters(), lr=self.lr)
        # mlp_opt = optim.Adam(self.nnmf.mlp.parameters(), lr=self.lr)

        mlp_s_opt = optim.RMSprop(self.nnmf.mlp_s.parameters(), lr=self.lr)
        # mlp_s_opt = optim.Adam(self.nnmf.mlp_s.parameters(), lr=self.lr)

        print('beginning training')

        while ep < self.epochs:

            for batch_id, batch in enumerate(trainloader, 0):

                # get data from batch
                pixel, frame, target = Variable(batch[0]), Variable(
                    batch[1]), Variable(batch[2])

                # send x_hat and s to gpu
                self.x_hat = self.x_hat.to(self.device)
                self.s = self.s.to(self.device)

                # send to gpu
                pixel = pixel.to(self.device)
                frame = frame.to(self.device)
                target = target.to(self.device).float()
                target = torch.reshape(target, shape=(target.shape[0], 1))

                # The sparse head only starts training after epoch 15, and
                # its L1 penalty is only added from epoch 20 on.
                if ep >= 15:
                    ##################################################################
                    # train mlp_s weights
                    mlp_s_opt.zero_grad()
                    x_out, s_out = self.nnmf.forward(pixel, frame, target)

                    # calculate loss on mlp_s and update mlp_s weights
                    mse_loss = self.mse_loss(target, x_out + s_out)
                    l1_loss = self.l1_loss(s_out)
                    if ep >= 20:
                        loss_mlp = mse_loss + l1_loss
                    else:
                        loss_mlp = mse_loss
                    loss_mlp.backward()
                    mlp_s_opt.step()

                ##################################################################
                # train mlp weights
                mlp_opt.zero_grad()
                x_out, s_out = self.nnmf.forward(pixel, frame, target)

                # calculate loss on mlp and update mlp weights
                if ep >= 15:
                    mse_loss = self.mse_loss(target, x_out + s_out)
                else:
                    mse_loss = self.mse_loss(target, x_out)
                lat_loss = self.lat_loss(pixel, frame)
                loss_mlp_x = mse_loss + lat_loss
                loss_mlp_x.backward()
                mlp_opt.step()

                ##################################################################
                # train latent weights
                latent_opt.zero_grad()
                x_out, s_out = self.nnmf.forward(pixel, frame, target)

                # calculate loss on latent and update latent weights
                if ep >= 15:
                    mse_loss = self.mse_loss(target, x_out + s_out)
                else:
                    mse_loss = self.mse_loss(target, x_out)
                    l1_loss = 0
                lat_loss = self.lat_loss(pixel, frame)
                loss_latent = mse_loss + lat_loss
                loss_latent.backward()
                latent_opt.step()

                # update x_hat and s
                self.x_hat[pixel, frame] = torch.squeeze(x_out.detach())
                self.s[pixel, frame] = torch.squeeze(s_out.detach())

                # print loss
                # NOTE(review): for ep >= 15 the printed l1_loss comes from
                # the earlier mlp_s pass, not this latent pass — confirm
                # that is intended.
                if batch_id % 10 == 0:
                    print(
                        '[%d] MSE Loss: %.5f | Latent Loss: %.5f | L1 Loss: %.5f'
                        % (ep + 1, mse_loss, lat_loss, l1_loss))

                # save x_hat and s to directory
            # NOTE(review): `dir` shadows the builtin, and the bare except
            # just repeats the saves when mkdir fails (directory already
            # exists) — os.makedirs(dir, exist_ok=True) would remove the
            # duplication.
            dir = '/local/home/jprovost/echo/out/mitral_valve/nnmf/' + self.save_loc + '/'
            print('Saving to dir:', dir)
            try:
                os.mkdir(dir)
                torch.save(self.x_hat.cpu(), dir + 'x_hat.pt')
                torch.save(self.s.cpu(), dir + 's.pt')
                torch.save(self.nnmf.Uprime_1.weight.cpu(), dir + 'Uprime1.pt')
                torch.save(self.nnmf.Uprime_2.weight.cpu(), dir + 'Uprime2.pt')
            except:
                torch.save(self.x_hat.cpu(), dir + 'x_hat.pt')
                torch.save(self.s.cpu(), dir + 's.pt')
                torch.save(self.nnmf.Uprime_1.weight.cpu(), dir + 'Uprime1.pt')
                torch.save(self.nnmf.Uprime_2.weight.cpu(), dir + 'Uprime2.pt')

            # increase epoch
            ep += 1
Пример #7
0
# Wrap the shared batch tensors (legacy torch.autograd.Variable API).
image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer: Adam and Adadelta are opt-in flags; RMSprop is the default.
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)


def val(net, dataset, criterion, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
Пример #8
0
    def _set_optimizer(self, model):
        """Build ``self.optimizer`` for *model* from
        ``self.config["optimizer_config"]``; an apex FusedAdam wrapped in a
        patched FP16_Optimizer is used when ``model.config["fp16"]`` is set.
        """
        optimizer_config = self.config["optimizer_config"]
        opt = optimizer_config["optimizer"]

        # Only optimize trainable parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())

        # Special optimizer for fp16
        if model.config["fp16"]:
            from apex.optimizers import FP16_Optimizer, FusedAdam

            # Subclass whose step() tolerates parameters with grad=None by
            # substituting zero gradients (the only behavioral change).
            class FP16_OptimizerMMTLModified(FP16_Optimizer):
                def step(self, closure=None):
                    """
                    Not supporting closure.
                    """
                    # First compute norm for all group so we know if there is overflow
                    grads_groups_flat = []
                    norm_groups = []
                    skip = False
                    for i, group in enumerate(self.fp16_groups):

                        # Only part that's changed -- zero out grads that are None
                        grads_to_use = []
                        for p in group:
                            if p.grad is None:
                                size = list(p.size())
                                grads_to_use.append(p.new_zeros(size))
                            else:
                                grads_to_use.append(p.grad)
                        grads_groups_flat.append(
                            _flatten_dense_tensors(grads_to_use))

                        norm_groups.append(
                            self._compute_grad_norm(grads_groups_flat[i]))
                        if norm_groups[i] == -1:  # TODO: early break
                            skip = True

                    # Overflow detected: update the loss scale, skip the step.
                    if skip:
                        self._update_scale(skip)
                        return

                    # norm is in fact norm*cur_scale
                    self.optimizer.step(
                        grads=[[g] for g in grads_groups_flat],
                        output_params=[[p] for p in self.fp16_groups_flat],
                        scale=self.cur_scale,
                        grad_norms=norm_groups,
                    )

                    # TODO: This may not be necessary; confirm if it is
                    for i in range(len(norm_groups)):
                        updated_params = _unflatten_dense_tensors(
                            self.fp16_groups_flat[i], self.fp16_groups[i])
                        for p, q in zip(self.fp16_groups[i], updated_params):
                            p.data = q.data

                    self._update_scale(False)
                    return

            optimizer = FusedAdam(
                parameters,
                **optimizer_config["optimizer_common"],
                bias_correction=False,
                max_grad_norm=1.0,
            )
            optimizer = FP16_OptimizerMMTLModified(optimizer,
                                                   dynamic_loss_scale=True)

        elif opt == "sgd":
            optimizer = optim.SGD(
                parameters,
                **optimizer_config["optimizer_common"],
                **optimizer_config["sgd_config"],
                weight_decay=self.config["l2"],
            )
        elif opt == "rmsprop":
            optimizer = optim.RMSprop(
                parameters,
                **optimizer_config["optimizer_common"],
                **optimizer_config["rmsprop_config"],
                weight_decay=self.config["l2"],
            )
        elif opt == "adam":
            optimizer = optim.Adam(
                parameters,
                **optimizer_config["optimizer_common"],
                **optimizer_config["adam_config"],
                weight_decay=self.config["l2"],
            )
        elif opt == "adamax":
            optimizer = optim.Adamax(
                parameters,
                **optimizer_config["optimizer_common"],
                **optimizer_config["adam_config"],
                weight_decay=self.config["l2"],
            )
        elif opt == "sparseadam":
            optimizer = optim.SparseAdam(
                parameters,
                **optimizer_config["optimizer_common"],
                **optimizer_config["adam_config"],
            )
            # NOTE(review): this l2 check runs after the optimizer is built
            # and raises a generic Exception — consider checking before
            # construction and raising ValueError instead.
            if self.config["l2"]:
                raise Exception(
                    "SparseAdam optimizer does not support weight_decay (l2 penalty)."
                )
        else:
            raise ValueError(f"Did not recognize optimizer option '{opt}'")
        self.optimizer = optimizer
Пример #9
0
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

# NOTE(review): flip_rate=8 looks out of range for a probability — confirm
# the intended value against CamVidDataset's definition.
val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=8)
val_loader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=8)

model = DeepLabv3(n_class)

if use_gpu:
    model.cuda()

# Per-pixel loss on raw logits — presumably one-hot / multi-label mask
# targets; verify against the dataset's target format.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.RMSprop(model.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=w_decay)
scheduler = lr_scheduler.StepLR(
    optimizer, step_size=step_size,
    gamma=gamma)  # decay LR by a factor of 0.5 every 30 epochs

# Per-epoch metric buffers: class-wise IoU and overall pixel accuracy.
IU_scores = np.zeros((epochs, n_class))
pixel_scores = np.zeros(epochs)


def train():
    for epoch in range(epochs):
        scheduler.step()

        ts = time.time()
        for iter, batch in enumerate(train_loader):
Пример #10
0
def train(config):
    """Train a character-level LSTM text generator on ``config.txt_file``,
    periodically printing stats and sampling sentences, and finally
    writing the collected sentences and accuracies to disk.
    """
    # Initialize the device which to run the model on
    device = torch.device(config.device)
    print('device used:', device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length, \
                                dataset.vocab_size, \
                                lstm_num_hidden=config.lstm_num_hidden, \
                                lstm_num_layers=config.lstm_num_layers, \
                                device=config.device)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)

    # Accumulated text reports written to files at the end.
    sentences = 'book used:' + config.txt_file
    accuracies = 'book used:' + config.txt_file + '\n'

    for epoch in range(config.epochs):

        sentences += '\n epoch: ' + str(epoch)
        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Only for time measurement of step through network
            t1 = time.time()

            #######################################################
            # Add more code here ...
            #######################################################

            batch_size = batch_inputs[0].size(0)

            # One-hot encode inputs into (seq_length, batch, vocab) via scatter_.
            x_batch = torch.zeros(config.seq_length, batch_size, \
                                  dataset.vocab_size)
            x_batch.scatter_(2, torch.stack(batch_inputs).unsqueeze_(-1), 1)
            x_batch = x_batch.to(device)
            y_batch = torch.stack(batch_targets).to(device)

            optimizer.zero_grad()
            nn_out, _, _ = model(x_batch)
            loss = criterion(nn_out.view(-1, dataset.vocab_size), \
                             y_batch.view(-1))
            loss.backward()
            # Clip gradients to stabilize LSTM training.
            nn.utils.clip_grad_norm_(model.parameters(), \
                                     max_norm=config.max_norm)
            optimizer.step()

            # Fraction of correctly predicted characters in the batch.
            accuracy = (torch.argmax(nn_out, dim=2) == y_batch).sum().item()\
                        / (batch_size * config.seq_length)

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:

                print("[{}] Train Step {:04d}/{:04d}, Batch Size = {},\
                       Examples/Sec = {:.2f}, "
                      "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        config.train_steps, config.batch_size, \
                        examples_per_second, accuracy, loss
                ))
                accuracies += str(accuracy) + ", "

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model
                char = torch.zeros(1, 1, dataset.vocab_size)
                char[0, 0, np.random.randint(0, dataset.vocab_size)] = 1
                char = char.to(device)
                sentence = generate_sentence(model, dataset, config.seq_length,\
                                             char, device, config.temp)
                sentences += '\n' + str(step) + ' | ' + sentence
                print(sentence)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')

    # NOTE(review): config.txt_file[7:] assumes a fixed-length path prefix
    # (e.g. 'assets/') — confirm against how txt_file is passed in.
    with open('sentences/' + config.txt_file[7:] + 'temp' + str(config.temp),
              'w') as f:
        f.write(sentences)
    with open('accuracies/' + config.txt_file[7:] + 'temp' + str(config.temp),
              'w') as f:
        f.write(accuracies)
Пример #11
0
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    config = parse_args()
    trainloader = DataLoader(ScanNet(),
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=1)

    model = Autoencoder().cuda()
    model.load_state_dict(
        torch.load('../DATA/ScanNet_Image/autoencoder{}.pth'.format(2)))
    model.train()
    optimizer = optim.RMSprop(model.parameters(), lr=1e-3)

    for epoch in range(2, config.epochs):
        for batch_idx, data in enumerate(trainloader):
            data = Variable(data).cuda()
            recon_batch = model(data)
            loss = loss_function(recon_batch, data)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % config.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader),
Пример #12
0
# Move the model to the chosen device and apply custom weight initialization.
model.to(device)
model.apply(weights_init)

# Save model dir
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)


# Load saved models
if args.load:
    model.load_state_dict(torch.load(args.load_model))

# Optimizer
# NOTE(review): no fallback branch — any other --optim value leaves
# `optimizer` undefined and fails later with a NameError.
if args.optim == 'RMS':
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, alpha=0.9)
elif args.optim == 'ADAM':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)


def load_data():

    dset_train = ReadImagesBaseLine(args.data, train=True, neg_type='CN', pos_type='AD')

    train_loader = DataLoader(dset_train, batch_size=args.batchsize, shuffle=True, **load_args)

    dset_val = ReadImagesBaseLine(args.data, train=False, neg_type='CN', pos_type='AD')

    val_loader = DataLoader(dset_val, batch_size=args.batchsize, shuffle=False, **load_args)

    print("Training Data : ", len(train_loader.dataset))
    parser.add_argument("--env", default=ENV_NAME,
                        help="Name of the environment, default=" + ENV_NAME)
    args = parser.parse_args()
    device = torch.device(GRAPHICS_CARD if args.cuda else "cpu")

    env = wrappers.make_atari(args.env)
    env = wrappers.wrap_deepmind(env, episode_life=False, frame_stack=True)
    exp_buffer = ExperienceBuffer(REPLAY_MEMORY_SIZE)
    agent = Agent(env, exp_buffer)

    net = model.DQN(AGENT_HIST_LENGTH, env.action_space.n).to(device)
    tgt_net = model.DQN(AGENT_HIST_LENGTH, env.action_space.n).to(device)
    tgt_net.load_state_dict(net.state_dict())

    criterion = nn.MSELoss()
    optimizer = optim.RMSprop(net.parameters(), lr=LEARNING_RATE,
                              momentum=GRAD_MOMENTUM, eps=MIN_SQ_GRAD)

    writer = SummaryWriter(comment="-" + args.env)

    remaining_time_buffer = collections.deque(maxlen=100)
    last_100_rewards_training = collections.deque(maxlen=100)
    last_100_rewards_test = collections.deque(maxlen=100)

    episode_idx = 0
    frame_idx = 0
    while frame_idx < MAX_FRAMES:
        episode_t = time.time()
        frame_idx_old = frame_idx
        total_loss = 0.0
        done_reward = None
        while done_reward is None:
Пример #14
0
def init_model(FLAGS,
               logger,
               initial_embeddings,
               vocab_size,
               num_classes,
               data_manager,
               logfile_header=None):
    """Build the SPINN-family model plus gold-tree and discriminator
    networks, together with their optimizers, as selected by FLAGS.

    Returns:
        (model, optimizer, trainer, gold_tree, discriminator,
         optimizer_D, optimizer_tree)
    """
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    else:
        raise NotImplementedError

    build_gtree = spinn.dicriminator.build_goldtree
    build_dis = spinn.dicriminator.build_discriminator

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
        context_args.input_dim = FLAGS.model_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1
        )  # view map x to shape [batch_size, seq_length, rest]
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim,
                            FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse,
                            mix=(FLAGS.model_type != "CBOW"))
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":

        # Identity encoder: feed the embeddings straight through.
        def encoder(x):
            return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x

    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        # NOTE(review): true division — FLAGS.model_dim / 2 is a float on
        # Python 3; `FLAGS.model_dim // 2` is likely intended here and below.
        composition_args.size = FLAGS.model_dim / 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim / 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":

        # NOTE(review): F.tanh is deprecated in newer PyTorch; torch.tanh
        # is the drop-in replacement.
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)

        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    gold_tree = build_gtree(data_manager, initial_embeddings, vocab_size,
                            FLAGS, context_args)
    discriminator = build_dis(FLAGS, context_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(),
                               lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999),
                               eps=1e-08)
        optimizer_D = optim.Adam(discriminator.parameters(),
                                 lr=FLAGS.learning_rate_d,
                                 betas=(0.9, 0.999),
                                 eps=1e-08)
        optimizer_tree = optim.Adam(gold_tree.parameters(),
                                    lr=FLAGS.learning_rate_d,
                                    betas=(0.9, 0.999),
                                    eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        # NOTE(review): this branch never assigns optimizer_tree, so the
        # return statement below would raise UnboundLocalError — confirm
        # whether RMSprop is actually used.
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=FLAGS.learning_rate,
                                  eps=1e-08)
        optimizer_D = optim.RMSprop(discriminator.parameters(),
                                    lr=FLAGS.learning_rate_d,
                                    eps=1e-08)
    elif FLAGS.optimizer_type == "YellowFin":
        raise NotImplementedError
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([
        reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()
    ])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer, gold_tree, discriminator, optimizer_D, optimizer_tree
Пример #15
0
# WGAN-style hyperparameters.
LAMBDA = 10  # Gradient penalty lambda hyperparameter
MAX_EPOCH = 100  # How many generator iterations to train for
D_G_INPUT_DIM = len(train_data.columns)  # one input per traffic feature
G_OUTPUT_DIM = len(train_data.columns)
D_OUTPUT_DIM = 1
CLAMP = 0.01  # presumably the critic weight-clip bound — usage not shown here

# read parameters of IDS: load the pretrained black-box classifier
# (kept frozen via eval() below).
ids_model = Blackbox_IDS(D_G_INPUT_DIM, 2)
param = th.load('save_model/IDS.pth')
ids_model.load_state_dict(param)
# read model
generator = Generator(D_G_INPUT_DIM, G_OUTPUT_DIM)
discriminator = Discriminator(D_G_INPUT_DIM, D_OUTPUT_DIM)

# Same learning rate for generator and discriminator.
optimizer_G = optim.RMSprop(generator.parameters(), lr=0.0001)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.0001)

batch_attack = create_batch2(raw_attack, BATCH_SIZE)
d_losses, g_losses = [], []
ids_model.eval()
generator.train()
discriminator.train()
cnt = -5  # NOTE(review): meaning of the -5 start value is not visible here
print("IDSGAN start training")
print("-" * 100)
for epoch in range(MAX_EPOCH):

    batch_normal = create_batch2(normal, BATCH_SIZE)
    run_g_loss = 0.
    run_d_loss = 0.
Пример #16
0
def main():
    """End-to-end training entry point: parse config, build data and model,
    train with an optional GECO constraint on the reconstruction error,
    checkpoint periodically, then evaluate on the test set and compute FID.

    NOTE(review): depends on project modules not visible in this chunk
    (forge/fet, GECO, evaluation, visualise_inference, fid_from_model,
    fprint, ELBO_DIV).
    """
    # ------------------------
    # SETUP
    # ------------------------

    # Parse flags
    config = forge.config()
    if config.debug:
        config.num_workers = 0
        config.batch_size = 2

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup checkpoint or resume
    logdir = osp.join(config.results_dir, config.run_name)
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, 'model.ckpt')

    # Using GPU(S)?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fprint(f"Use GPU: {config.gpu}")
    if config.gpu and config.multi_gpu and torch.cuda.device_count() > 1:
        fprint(f"Using {torch.cuda.device_count()} GPUs!")
        config.num_workers = torch.cuda.device_count() * config.num_workers
    else:
        config.multi_gpu = False

    # Print flags
    # fet.print_flags()
    # TODO(martin) make this cleaner
    fprint(json.dumps(fet._flags.FLAGS.__flags, indent=4, sort_keys=True))

    # Setup TensorboardX SummaryWriter
    writer = SummaryWriter(logdir)

    # Load data
    train_loader, val_loader, test_loader = fet.load(config.data_config, config)
    num_elements = 3 * config.img_size**2  # Assume three input channels

    # Load model
    model = fet.load(config.model_config, config)
    fprint(model)
    if config.geco:
        # Goal is specified per pixel & channel so it doesn't need to
        # be changed for different resolutions etc.
        geco_goal = config.g_goal * num_elements
        # Scale step size to get similar update at different resolutions
        geco_lr = config.g_lr * (64**2 / config.img_size**2)
        geco = GECO(geco_goal, geco_lr, config.g_alpha, config.g_init,
                    config.g_min, config.g_speedup)
        beta = geco.beta
    else:
        beta = torch.tensor(config.beta)

    # Setup optimiser
    if config.optimiser == 'rmsprop':
        optimiser = optim.RMSprop(model.parameters(), config.learning_rate)
    elif config.optimiser == 'adam':
        optimiser = optim.Adam(model.parameters(), config.learning_rate)
    elif config.optimiser == 'sgd':
        optimiser = optim.SGD(model.parameters(), config.learning_rate, 0.9)
    # NOTE(review): no else branch -- an unrecognised config.optimiser leaves
    # 'optimiser' undefined and raises NameError below; confirm the flag is
    # validated upstream.

    # Try to restore model and optimiser from checkpoint
    iter_idx = 0
    if resume_checkpoint is not None:
        fprint(f"Restoring checkpoint from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location='cpu')
        # Restore model & optimiser
        model_state_dict = checkpoint['model_state_dict']
        # NOTE(review): pixel-coordinate buffers are dropped before loading --
        # presumably they are rebuilt for the current resolution; confirm.
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
        model.load_state_dict(model_state_dict)
        optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
        # Restore GECO
        if config.geco and 'beta' in checkpoint:
            geco.beta = checkpoint['beta']
        if config.geco and 'err_ema' in checkpoint:
            geco.err_ema = checkpoint['err_ema']
        # Update starting iter
        iter_idx = checkpoint['iter_idx'] + 1
    fprint(f"Starting training at iter = {iter_idx}")

    # Push model to GPU(s)?
    if config.multi_gpu:
        fprint("Wrapping model in DataParallel.")
        model = nn.DataParallel(model)
    if config.gpu:
        fprint("Pushing model to GPU.")
        model = model.cuda()
        if config.geco:
            geco.to_cuda()

    # ------------------------
    # TRAINING
    # ------------------------

    model.train()
    timer = time.time()
    while iter_idx <= config.train_iter:
        for train_batch in train_loader:
            # Parse data
            train_input = train_batch['input']
            if config.gpu:
                train_input = train_input.cuda()

            # Forward propagation
            optimiser.zero_grad()
            output, losses, stats, att_stats, comp_stats = model(train_input)

            # Reconstruction error
            err = losses.err.mean(0)
            # KL divergences
            kl_m, kl_l = torch.tensor(0), torch.tensor(0)
            # -- KL stage 1
            if 'kl_m' in losses:
                kl_m = losses.kl_m.mean(0)
            elif 'kl_m_k' in losses:
                kl_m = torch.stack(losses.kl_m_k, dim=1).mean(dim=0).sum()
            # -- KL stage 2
            if 'kl_l' in losses:
                kl_l = losses.kl_l.mean(0)
            elif 'kl_l_k' in losses:
                kl_l = torch.stack(losses.kl_l_k, dim=1).mean(dim=0).sum()

            # Compute ELBO
            elbo = (err + kl_l + kl_m).detach()
            # NOTE(review): err_new / kl_new are computed but never used in
            # this chunk.
            err_new = err.detach()
            kl_new = (kl_m + kl_l).detach()
            # Compute MSE / RMSE
            mse_batched = ((train_input-output)**2).mean((1, 2, 3)).detach()
            rmse_batched = mse_batched.sqrt()
            mse, rmse = mse_batched.mean(0), rmse_batched.mean(0)

            # Main objective
            if config.geco:
                loss = geco.loss(err, kl_l + kl_m)
                beta = geco.beta
            else:
                if config.beta_warmup:
                    # Increase beta linearly over 20% of training
                    beta = config.beta*iter_idx / (0.2*config.train_iter)
                    beta = torch.tensor(beta).clamp(0, config.beta)
                else:
                    beta = config.beta
                loss = err + beta*(kl_l + kl_m)

            # Backprop and optimise
            loss.backward()
            optimiser.step()

            # Heartbeat log
            if (iter_idx % config.report_loss_every == 0 or
                    float(elbo) > ELBO_DIV or config.debug):
                # Print output and write to file
                ps = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                ps += f' {config.run_name} | '
                ps += f'[{iter_idx}/{config.train_iter:.0e}]'
                ps += f' elb: {float(elbo):.0f} err: {float(err):.0f} '
                if 'kl_m' in losses or 'kl_m_k' in losses:
                    ps += f' klm: {float(kl_m):.1f}'
                if 'kl_l' in losses or 'kl_l_k' in losses:
                    ps += f' kll: {float(kl_l):.1f}'
                ps += f' bet: {float(beta):.1e}'
                # Seconds per batch, averaged over the reporting interval.
                s_per_b = (time.time()-timer)
                if not config.debug:
                    s_per_b /= config.report_loss_every
                timer = time.time()  # Reset timer
                ps += f' - {s_per_b:.2f} s/b'
                fprint(ps)

                # TensorBoard logging
                # -- Optimisation stats
                writer.add_scalar('optim/beta', beta, iter_idx)
                writer.add_scalar('optim/s_per_batch', s_per_b, iter_idx)
                if config.geco:
                    writer.add_scalar('optim/geco_err_ema',
                                      geco.err_ema, iter_idx)
                    writer.add_scalar('optim/geco_err_ema_element',
                                      geco.err_ema/num_elements, iter_idx)
                # -- Main loss terms
                writer.add_scalar('train/err', err, iter_idx)
                writer.add_scalar('train/err_element', err/num_elements, iter_idx)
                writer.add_scalar('train/kl_m', kl_m, iter_idx)
                writer.add_scalar('train/kl_l', kl_l, iter_idx)
                writer.add_scalar('train/elbo', elbo, iter_idx)
                writer.add_scalar('train/loss', loss, iter_idx)
                writer.add_scalar('train/mse', mse, iter_idx)
                writer.add_scalar('train/rmse', rmse, iter_idx)
                # -- Per step loss terms
                for key in ['kl_l_k', 'kl_m_k']:
                    if key not in losses: continue
                    for step, val in enumerate(losses[key]):
                        writer.add_scalar(f'train_steps/{key}{step}',
                                          val.mean(0), iter_idx)
                # -- Attention stats
                if config.log_distributions and att_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in att_stats: continue
                        for step, val in enumerate(att_stats[key]):
                            writer.add_histogram(f'att_{key}_{step}',
                                                 val, iter_idx)
                # -- Component stats
                if config.log_distributions and comp_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in comp_stats: continue
                        for step, val in enumerate(comp_stats[key]):
                            writer.add_histogram(f'comp_{key}_{step}',
                                                 val, iter_idx)

            # Save checkpoints
            # NOTE(review): ckpt_freq is a float (true division), so the
            # modulo below is float modulo -- fine when train_iter divides
            # evenly by num_checkpoints, otherwise fragile; confirm.
            ckpt_freq = config.train_iter / config.num_checkpoints
            if iter_idx % ckpt_freq == 0:
                ckpt_file = '{}-{}'.format(checkpoint_name, iter_idx)
                fprint(f"Saving model training checkpoint to: {ckpt_file}")
                if config.multi_gpu:
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                ckpt_dict = {'iter_idx': iter_idx,
                             'model_state_dict': model_state_dict,
                             'optimiser_state_dict': optimiser.state_dict(),
                             'elbo': elbo}
                if config.geco:
                    ckpt_dict['beta'] = geco.beta
                    ckpt_dict['err_ema'] = geco.err_ema
                torch.save(ckpt_dict, ckpt_file)

            # Run validation and log images
            if (iter_idx % config.run_validation_every == 0 or
                    float(elbo) > ELBO_DIV):
                # Weight and gradient histograms
                if config.log_grads_and_weights:
                    for name, param in model.named_parameters():
                        writer.add_histogram(f'weights/{name}', param.data,
                                             iter_idx)
                        writer.add_histogram(f'grads/{name}', param.grad,
                                             iter_idx)
                # TensorboardX logging - images
                visualise_inference(model, train_batch, writer, 'train', 
                                    iter_idx)
                # Validation
                fprint("Running validation...")
                eval_model = model.module if config.multi_gpu else model
                evaluation(eval_model, val_loader, writer, config, iter_idx,
                           N_eval=config.N_eval)

            # Increment counter
            iter_idx += 1
            if iter_idx > config.train_iter:
                break

            # Exit if training has diverged
            if elbo.item() > ELBO_DIV:
                fprint(f"ELBO: {elbo.item()}")
                fprint(f"ELBO has exceeded {ELBO_DIV} - training has diverged.")
                sys.exit()

    # ------------------------
    # TESTING
    # ------------------------

    # Save final checkpoint
    ckpt_file = '{}-{}'.format(checkpoint_name, 'FINAL')
    fprint(f"Saving model training checkpoint to: {ckpt_file}")
    if config.multi_gpu:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    ckpt_dict = {'iter_idx': iter_idx,
                 'model_state_dict': model_state_dict,
                 'optimiser_state_dict': optimiser.state_dict()}
    if config.geco:
        ckpt_dict['beta'] = geco.beta
        ckpt_dict['err_ema'] = geco.err_ema
    torch.save(ckpt_dict, ckpt_file)

    # Test evaluation
    fprint("STARTING TESTING...")
    eval_model = model.module if config.gpu and config.multi_gpu else model
    final_elbo = evaluation(
        eval_model, test_loader, None, config, iter_idx, N_eval=config.N_eval)
    fprint(f"TEST ELBO = {float(final_elbo)}")

    # FID computation
    try:
        fid_from_model(model, test_loader, img_dir=osp.join('/tmp', logdir))
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    # Close writer
    writer.close()
Пример #17
0
        x = self.dropout1(x)
        x = F.relu(self.conv3(x))
        x = self.pool(self.conv4(x))
        x = self.dropout2(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.flat1(x))
        x = self.dropout1(x)
        x = self.flat2(x)
        return x


# --- CNN classifier training script (module level) ---
# NOTE(review): Net, train_loader, and train_on_gpu are defined earlier in
# the file and are not visible in this chunk.
model = Net()
print(model)

# Multi-class classification loss; RMSprop with a small L2 weight penalty.
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-6)

# Train the network
n_epochs = 5
valid_loss_min = np.Inf  # best validation loss so far (for checkpointing)

for epoch in range(1, n_epochs + 1):
    print('[INFO] Starting training...')

    train_loss = 0.0
    valid_loss = 0.0

    model.train()
    for data, target in train_loader:
        # Move the batch to the GPU when one is available.
        if train_on_gpu:
            data, target = data.cuda(), target.cuda()
Пример #18
0
    def train(self):
        """Train the BiGAN: generator G, encoder E, and discriminator D.

        Network variants are chosen from ``self.args`` (small architectures
        for MNIST, larger ones otherwise), with optional restore from
        ``self.args.pretrained_path``. With ``self.args.wasserstein`` the
        WGAN objective (RMSprop + weight clipping) is used; otherwise the
        standard BCE objective with Adam. ``self.args.freeze_GD`` trains
        only the encoder against a frozen G and D.
        """
        if self.args.data == 'mnist':
            img_channels = 1
            self.G = Generator_small(img_channels,
                                     self.args.latent_dim,
                                     use_tanh=self.args.normalize_data).to(
                                         self.device)
            self.E = Encoder_small(img_channels, self.args.latent_dim,
                                   self.args.use_relu_z,
                                   self.args.first_filter_size).to(self.device)
            self.D = Discriminator_small(img_channels, self.args.latent_dim,
                                         self.args.wasserstein).to(self.device)
        else:
            img_channels = 3
            self.G = Generator(img_channels,
                               self.args.latent_dim,
                               use_tanh=self.args.normalize_data).to(
                                   self.device)
            self.E = Encoder(img_channels, self.args.latent_dim,
                             self.args.use_relu_z,
                             self.args.first_filter_size).to(self.device)
            self.D = Discriminator(img_channels, self.args.latent_dim,
                                   self.args.wasserstein).to(self.device)

        # Restore pretrained weights when available, else random init.
        if self.args.pretrained_path and os.path.exists(
                self.args.pretrained_path):
            ckpt = torch.load(self.args.pretrained_path)
            self.G.load_state_dict(ckpt['G'])
            self.E.load_state_dict(ckpt['E'])
            self.D.load_state_dict(ckpt['D'])
        else:
            self.G.apply(weights_init_normal)
            self.E.apply(weights_init_normal)
            self.D.apply(weights_init_normal)

        if self.args.freeze_GD:
            # Train the encoder only, with the generator & discriminator frozen.
            self.G.eval()
            self.D.eval()
            optimizer_d = None
            if self.args.wasserstein:
                optimizer_ge = optim.RMSprop(list(self.E.parameters()),
                                             lr=self.args.lr_rmsprop)
            else:
                optimizer_ge = optim.Adam(list(self.E.parameters()),
                                          lr=self.args.lr_adam,
                                          weight_decay=1e-6)
        else:
            if self.args.wasserstein:
                optimizer_ge = optim.RMSprop(list(self.G.parameters()) +
                                             list(self.E.parameters()),
                                             lr=self.args.lr_rmsprop)
                optimizer_d = optim.RMSprop(self.D.parameters(),
                                            lr=self.args.lr_rmsprop)
            else:
                optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                          list(self.E.parameters()),
                                          lr=self.args.lr_adam,
                                          weight_decay=1e-6)
                optimizer_d = optim.Adam(self.D.parameters(),
                                         lr=self.args.lr_adam,
                                         weight_decay=1e-6)

        # Fixed latent batch so sample grids are comparable across epochs.
        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs):
            ge_losses = 0
            d_losses = 0
            for x, xi in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Noise for improving training.
                # NOTE(review): epoch < num_epochs is always true for epoch in
                # range(num_epochs), so the else branch below is unreachable
                # despite the in-code note -- confirm intent.
                if epoch < self.args.num_epochs:
                    noise1 = Variable(torch.Tensor(x.size()).normal_(
                        0, 0.1 * (self.args.num_epochs - epoch) /
                        self.args.num_epochs),
                                      requires_grad=False).to(self.device)
                    noise2 = Variable(torch.Tensor(x.size()).normal_(
                        0, 0.1 * (self.args.num_epochs - epoch) /
                        self.args.num_epochs),
                                      requires_grad=False).to(self.device)
                else:
                    # NOTE: added by BB: else the above reports error about std=0 in the last epoch
                    noise1, noise2 = 0, 0

                #Cleaning gradients.
                if optimizer_d:
                    optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                # BB's NOTE: x_true has values in [0, 1]
                z_true = self.E(x_true)

                #Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                if self.args.wasserstein:
                    loss_d = -torch.mean(out_true) + torch.mean(out_fake)
                else:
                    loss_d = criterion(out_true, y_true) + criterion(
                        out_fake, y_fake)

                #Computing gradients and backpropagate.
                loss_d.backward()
                if optimizer_d:
                    optimizer_d.step()

                #Cleaning gradients.
                optimizer_ge.zero_grad()

                # Second forward pass with fresh samples for the G/E update.
                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                #Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                if self.args.wasserstein:
                    loss_ge = -torch.mean(out_fake) + torch.mean(out_true)
                else:
                    loss_ge = criterion(out_fake, y_true) + criterion(
                        out_true, y_fake)

                if self.args.use_l2_loss:
                    loss_ge += self.get_latent_l2_loss()
                    loss_ge += self.get_image_l2_loss(x_true)

                loss_ge.backward()
                optimizer_ge.step()

                # WGAN Lipschitz constraint via weight clipping.
                if self.args.wasserstein:
                    for p in self.D.parameters():
                        p.data.clamp_(-self.args.clamp, self.args.clamp)

                ge_losses += loss_ge.item()
                d_losses += loss_d.item()

                if USE_WANDB:
                    wandb.log({
                        'iter': epoch * len(self.train_loader) + xi,
                        'loss_ge': loss_ge.item(),
                        'loss_d': loss_d.item(),
                    })

            if epoch % 50 == 0:
                images = self.G(fixed_z).data
                vutils.save_image(images, './images/{}_fake.png'.format(epoch))
                images_lst = [
                    wandb.Image(image.cpu().numpy().transpose(1, 2, 0) * 255,
                                caption="Epoch {}, #{}".format(epoch, ii))
                    for ii, image in enumerate(images)
                ]
                # NOTE(review): this wandb.log is not guarded by USE_WANDB,
                # unlike the per-batch logging above -- confirm.
                wandb.log({"examples": images_lst})
                if self.args.save_path:
                    save_ckpt(
                        self,
                        self.args.save_path.replace(
                            '.pt', '_tmp_e{}.pt'.format(epoch)))
                else:
                    save_ckpt(
                        self, 'ckpt_epoch{}_tmp_e{}.pt'.format(
                            self.args.num_epochs, epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
Пример #19
0
            # stored in the device (mostly, if it's a GPU)
            thisArchit.to(device)

            #############
            # OPTIMIZER #
            #############

            if thisTrainer == 'ADAM':
                thisOptim = optim.Adam(thisArchit.parameters(),
                                       lr=learningRate,
                                       betas=(beta1, beta2))
            elif thisTrainer == 'SGD':
                thisOptim = optim.SGD(thisArchit.parameters(), lr=learningRate)
            elif thisTrainer == 'RMSprop':
                thisOptim = optim.RMSprop(thisArchit.parameters(),
                                          lr=learningRate,
                                          alpha=beta1)

            ########
            # LOSS #
            ########

            thisLossFunction = lossFunction

            #########
            # MODEL #
            #########

            Polynomial = model.Model(thisArchit, thisLossFunction, thisOptim,
                                     thisName, saveDir, order)
Пример #20
0
                            pin_memory=True)

    # logging training overview
    print('-----\n Start training:')
    print(
        f'epochs: {args.epochs} \t batch size: {args.batch_size} \t learning rate: {args.learning_rate} \t'
    )
    print(
        f'training size: {n_train} \t validation size: {n_val} \t checkpoints_dir: {args.checkpoints_dir} \t images downscale: {args.down_scale}'
    )
    print('-----')

    ## --- Set up training
    global_step = 0
    optimizer = optim.RMSprop(net.parameters(),
                              lr=args.learning_rate,
                              weight_decay=1e-8)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    ## --- Start training
    epoch_loss_list = []
    val_score_list = []
    num_batches_per_epoch = len(dataset) // args.batch_size
    for epoch in range(args.epochs):
        net.train()

        epoch_loss = 0
        for batch in train_loader:
Пример #21
0
 def default_optimizer(self):
     """Build the default RMSprop optimiser for ``self.net``.

     Learning rate, momentum, and smoothing constant come from
     ``self.params`` (``rmsprop_lr``, ``rmsprop_momentum``,
     ``rmsprop_alpha``).
     """
     hp = self.params
     return optim.RMSprop(
         self.net.parameters(),
         momentum=hp.rmsprop_momentum,
         alpha=hp.rmsprop_alpha,
         lr=hp.rmsprop_lr,
     )
def main():
    """Train an A2C agent on MiniPacman across vectorised envs, then save
    the policy and replay one greedy episode with frame display.

    NOTE(review): depends on project names not visible in this chunk
    (MiniPacman, SubprocVecEnv, ActorCritic, RolloutStorage).
    """
    mode = "regular"
    num_envs = 16

    def make_env():
        # Factory returning a thunk, as required by SubprocVecEnv.
        def _thunk():
            env = MiniPacman(mode, 1000)
            return env

        return _thunk

    envs = [make_env() for i in range(num_envs)]
    envs = SubprocVecEnv(envs)

    state_shape = envs.observation_space.shape

    #a2c hyperparams:
    gamma = 0.99
    entropy_coef = 0.01
    value_loss_coef = 0.5
    max_grad_norm = 0.5
    num_steps = 5
    num_frames = int(10e3)

    #rmsprop hyperparams:
    lr = 7e-4
    eps = 1e-5
    alpha = 0.99

    #Init a2c and rmsprop
    actor_critic = ActorCritic(envs.observation_space.shape,
                               envs.action_space.n)
    optimizer = optim.RMSprop(actor_critic.parameters(),
                              lr,
                              eps=eps,
                              alpha=alpha)

    #if USE_CUDA:
    #    actor_critic = actor_critic.cuda()

    rollout = RolloutStorage(num_steps, num_envs, envs.observation_space.shape)
    #rollout.cuda()

    all_rewards = []
    all_losses = []

    state = envs.reset()
    state = torch.FloatTensor(np.float32(state))

    rollout.states[0].copy_(state)

    episode_rewards = torch.zeros(num_envs, 1)
    final_rewards = torch.zeros(num_envs, 1)

    for i_update in tqdm(range(num_frames)):

        # Collect num_steps transitions from every env into the rollout.
        for step in range(num_steps):
            action = actor_critic.act(autograd.Variable(state))

            next_state, reward, done, _ = envs.step(
                action.squeeze(1).cpu().data.numpy())

            reward = torch.FloatTensor(reward).unsqueeze(1)
            episode_rewards += reward
            # masks is 0 for finished episodes: freeze their reward into
            # final_rewards and reset the running total.
            masks = torch.FloatTensor(1 - np.array(done)).unsqueeze(1)
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            #if USE_CUDA:
            #    masks = masks.cuda()

            state = torch.FloatTensor(np.float32(next_state))
            rollout.insert(step, state, action.data, reward, masks)

        # NOTE(review): volatile=True is a removed pre-0.4 PyTorch API;
        # modern code uses torch.no_grad() here.
        _, next_value = actor_critic(
            autograd.Variable(rollout.states[-1], volatile=True))
        next_value = next_value.data

        returns = rollout.compute_returns(next_value, gamma)

        logit, action_log_probs, values, entropy = actor_critic.evaluate_actions(
            autograd.Variable(rollout.states[:-1]).view(-1, *state_shape),
            autograd.Variable(rollout.actions).view(-1, 1))

        values = values.view(num_steps, num_envs, 1)
        action_log_probs = action_log_probs.view(num_steps, num_envs, 1)
        advantages = autograd.Variable(returns) - values

        value_loss = advantages.pow(2).mean()
        # Advantages are detached for the policy term (critic-only gradient
        # flows through value_loss).
        action_loss = -(autograd.Variable(advantages.data) *
                        action_log_probs).mean()

        optimizer.zero_grad()
        loss = value_loss * value_loss_coef + action_loss - entropy * entropy_coef
        loss.backward()
        # NOTE(review): clip_grad_norm is deprecated in favour of
        # clip_grad_norm_ (in-place) since PyTorch 1.0.
        nn.utils.clip_grad_norm(actor_critic.parameters(), max_grad_norm)
        optimizer.step()

        # NOTE(review): i_update < num_frames always, so this condition is
        # only true at i_update == 0 -- plots appear once; a smaller logging
        # interval was likely intended. Confirm.
        if i_update % num_frames == 0:
            all_rewards.append(final_rewards.mean())
            all_losses.append(loss.item())

            #clear_output(True)
            plt.figure(figsize=(20, 5))
            plt.subplot(131)
            plt.title('epoch %s. reward: %s' %
                      (i_update, np.mean(all_rewards[-10:])))
            plt.plot(all_rewards)
            plt.subplot(132)
            plt.title('loss %s' % all_losses[-1])
            plt.plot(all_losses)
            plt.show()

        rollout.after_update()

    torch.save(actor_critic.state_dict(), "actor_critic_" + mode)

    import time

    def displayImage(image, step, reward):
        # Render one frame with step/reward in the title, briefly pausing
        # so the replay is watchable.
        #clear_output(True)
        s = "step: " + str(step) + " reward: " + str(reward)
        plt.figure(figsize=(10, 3))
        plt.title(s)
        plt.imshow(image)
        plt.show()
        time.sleep(0.1)

    env = MiniPacman(mode, 1000)

    # Replay a single episode with the trained policy.
    done = False
    state = env.reset()
    total_reward = 0
    step = 1

    while not done:
        current_state = torch.FloatTensor(state).unsqueeze(0)
        #if USE_CUDA:
        #    current_state = current_state.cuda()

        action = actor_critic.act(autograd.Variable(current_state))

        next_state, reward, done, _ = env.step(action.data[0, 0])
        total_reward += reward
        state = next_state

        image = torch.FloatTensor(state).permute(1, 2, 0).cpu().numpy()
        displayImage(image, step, total_reward)
        step += 1
Пример #23
0
def main(args):
    """Train a graph GAN on MNIST superpixels.

    Builds a run name from the enabled variants (WGAN / GRU / GCNN), lays out
    the output directory tree under ``args.dir_path``, loads or creates the
    generator/discriminator pair, then runs the adversarial training loop,
    periodically saving sample figures, loss curves and model checkpoints.

    NOTE(review): relies on module-level configuration defined elsewhere in
    this file (WGAN, GRU, GCNN, LSGAN, TRAIN, NUM, INT_DIFFS, device, url,
    the model classes, exists/mkdir/listdir, plt, np, ...).
    """
    torch.manual_seed(4)
    torch.autograd.set_detect_anomaly(True)

    # Compose a unique run name from the enabled model variants.
    name = [args.name]
    if WGAN:
        name.append('wgan')
    if GRU:
        name.append('gru')
    if GCNN:
        name.append('gcnn')

    # name.append('num_iters_{}'.format(args.num_iters))
    # name.append('num_critic_{}'.format(args.num_critic))
    args.name = '_'.join(name)

    # Output directory layout, all rooted at args.dir_path.
    args.model_path = args.dir_path + '/models/'
    args.losses_path = args.dir_path + '/losses/'
    args.args_path = args.dir_path + '/args/'
    args.figs_path = args.dir_path + '/figs/'
    args.dataset_path = args.dir_path + '/dataset/'
    args.err_path = args.dir_path + '/err/'

    if(not exists(args.model_path)):
        mkdir(args.model_path)
    if(not exists(args.losses_path)):
        mkdir(args.losses_path)
    if(not exists(args.args_path)):
        mkdir(args.args_path)
    if(not exists(args.figs_path)):
        mkdir(args.figs_path)
    if(not exists(args.err_path)):
        mkdir(args.err_path)
    if(not exists(args.dataset_path)):
        mkdir(args.dataset_path)
        try:
            # python2: urllib.urlretrieve does not exist on python3, so the
            # python3 path is reached via AttributeError.
            file_tmp = urllib.urlretrieve(url, filename=None)[0]
        except AttributeError:
            # python3 -- was a bare ``except:``, which also swallowed real
            # download errors (and KeyboardInterrupt).
            # NOTE(review): the two branches pass different ``filename``
            # values (None vs args.dataset) -- confirm which is intended.
            file_tmp = urllib.request.urlretrieve(url, filename=args.dataset)[0]

    prev_models = [f[:-4] for f in listdir(args.args_path)]  # removing .txt

    if (args.name in prev_models):
        print("name already used")
        # if(not args.load_model):
        #    sys.exit()
    else:
        mkdir(args.losses_path + args.name)
        mkdir(args.model_path + args.name)
        mkdir(args.figs_path + args.name)

    if(not args.load_model):
        # Persist the run configuration so it can be restored on --load_model.
        with open(args.args_path + args.name + ".txt", "w+") as f:
            f.write(str(vars(args)))
    else:
        temp = args.start_epoch
        with open(args.args_path + args.name + ".txt", "r") as f:
            # SECURITY: eval() is only acceptable here because the file was
            # written by this script itself; never point it at untrusted input.
            loaded = eval(f.read())
        # str(vars(args)) serialized a plain dict, so rebuild an attribute-style
        # namespace: the original assigned attributes directly on the dict,
        # which raised AttributeError on the very next line.
        from argparse import Namespace
        args = Namespace(**loaded) if isinstance(loaded, dict) else loaded
        args.load_model = True
        args.start_epoch = temp
        # return args2

    def pf(data):
        # torch_geometric pre_filter: keep only samples whose label is args.num
        return data.y == args.num

    pre_filter = pf if args.num != -1 else None

    print("loading")

    # Change to True !!
    X = SuperpixelsDataset(args.dataset_path, args.num_hits, train=TRAIN, num=NUM, device=device)
    tgX = MNISTSuperpixels(args.dir_path, train=TRAIN, pre_transform=T.Cartesian(), pre_filter=pre_filter)

    X_loaded = DataLoader(X, shuffle=True, batch_size=args.batch_size, pin_memory=True)
    tgX_loaded = tgDataLoader(tgX, shuffle=True, batch_size=args.batch_size)

    print("loaded")

    if(args.load_model):
        G = torch.load(args.model_path + args.name + "/G_" + str(args.start_epoch) + ".pt")
        D = torch.load(args.model_path + args.name + "/D_" + str(args.start_epoch) + ".pt")
    else:
        G = Graph_Generator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.gru_hidden_size, args.gru_num_layers, args.num_iters, args.num_hits, args.dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, int_diffs=INT_DIFFS, gru=GRU, device=device).to(device)
        if(GCNN):
            D = MoNet(kernel_size=args.kernel_size, dropout=args.dropout, device=device).to(device)
            # D = Gaussian_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.gru_hidden_size, args.gru_num_layers, args.num_iters, args.num_hits, args.dropout, args.leaky_relu_alpha, kernel_size=args.kernel_size, hidden_node_size=args.hidden_node_size, int_diffs=INT_DIFFS, gru=GRU).to(device)
        else:
            D = Graph_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.gru_hidden_size, args.gru_num_layers, args.num_iters, args.num_hits, args.dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, int_diffs=INT_DIFFS, gru=GRU, device=device).to(device)

    print("Models loaded")

    # WGAN uses RMSprop (per the original WGAN paper); otherwise Adam.
    if(WGAN):
        G_optimizer = optim.RMSprop(G.parameters(), lr=args.lr_gen)
        D_optimizer = optim.RMSprop(D.parameters(), lr=args.lr_disc)
    else:
        G_optimizer = optim.Adam(G.parameters(), lr=args.lr_gen, weight_decay=5e-4)
        D_optimizer = optim.Adam(D.parameters(), lr=args.lr_disc, weight_decay=5e-4)

    print("optimizers loaded")

    # Latent prior for the generator.
    normal_dist = Normal(0, 0.2)

    def wasserstein_loss(y_out, y_true):
        # Standard WGAN critic loss: maximize D(real) - D(fake).
        return -torch.mean(y_out * y_true)

    if(WGAN):
        criterion = wasserstein_loss
    else:
        if(LSGAN):
            criterion = torch.nn.MSELoss()
        else:
            criterion = torch.nn.BCELoss()

    # print(criterion(torch.tensor([1.0]),torch.tensor([-1.0])))

    def gen(num_samples, noise=0):
        """Sample ``num_samples`` graphs from G (drawing fresh noise if none given)."""
        if(noise == 0):
            noise = normal_dist.sample((num_samples, args.num_hits, args.hidden_node_size)).to(device)

        return G(noise)

    # transform my format to torch_geometric's
    def tg_transform(X):
        """Convert a dense (batch, 75, 3) node tensor into a torch_geometric Batch.

        Edges connect node pairs closer than ``cutoff``; edge attributes are
        the pairwise position differences, normalized per-graph to [0, 1].
        """
        batch_size = X.size(0)
        cutoff = 0.32178  # found empirically to match closest to Superpixels

        pos = X[:, :, :2]

        # All-pairs position differences within each graph (75 nodes).
        x1 = pos.repeat(1, 1, 75).reshape(batch_size, 75*75, 2)
        x2 = pos.repeat(1, 75, 1)

        diff_norms = torch.norm(x2 - x1 + 1e-12, dim=2)

        diff = x2-x1
        diff = diff[diff_norms < cutoff]

        norms = diff_norms.reshape(batch_size, 75, 75)
        # neighborhood rows: (graph index, source node, target node)
        neighborhood = torch.nonzero(norms < cutoff, as_tuple=False)
        edge_attr = diff[neighborhood[:, 1] != neighborhood[:, 2]]

        neighborhood = neighborhood[neighborhood[:, 1] != neighborhood[:, 2]]  # remove self-loops
        unique, counts = torch.unique(neighborhood[:, 0], return_counts=True)
        edge_slices = torch.cat((torch.tensor([0]).to(device), counts.cumsum(0)))
        edge_index = neighborhood[:, 1:].transpose(0, 1)

        # normalizing edge attributes per graph to the [0, 1] range
        edge_attr_list = list()
        for i in range(batch_size):
            start_index = edge_slices[i]
            end_index = edge_slices[i+1]
            temp = diff[start_index:end_index]
            max_attr = torch.max(temp)  # renamed from ``max``: don't shadow the builtin
            temp = temp/(2*max_attr + 1e-12) + 0.5
            edge_attr_list.append(temp)

        edge_attr = torch.cat(edge_attr_list)

        # Shift intensities/positions into the Superpixels coordinate frame.
        x = X[:, :, 2].reshape(batch_size*75, 1)+0.5
        pos = 27*pos.reshape(batch_size*75, 2)+13.5

        # batch[i] = graph index of node i (75 nodes per graph).
        zeros = torch.zeros(batch_size*75, dtype=int).to(device)
        zeros[torch.arange(batch_size)*75] = 1
        batch = torch.cumsum(zeros, 0)-1

        return Batch(batch=batch, x=x, edge_index=edge_index.contiguous(), edge_attr=edge_attr, y=None, pos=pos)

    def draw_graph(graph, node_r, im_px):
        """Rasterize one graph: a filled disk per node, intensity = node feature."""
        imd = im_px + node_r
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        img = np.zeros((imd, imd), dtype=float)

        circles = []
        for node in graph:
            circles.append((draw.circle_perimeter(int(node[1]), int(node[0]), node_r), draw.disk((int(node[1]), int(node[0])), node_r), node[2]))

        for circle in circles:
            img[circle[1]] = circle[2]

        return img

    def save_sample_outputs(name, epoch, dlosses, glosses):
        """Save a 10x10 grid of generated samples and the loss curves for ``epoch``."""
        print("drawing figs")
        fig = plt.figure(figsize=(10, 10))

        num_ims = 100
        node_r = 30
        im_px = 1000

        # Generate at least num_ims samples in batch-size chunks.
        gen_out = gen(args.batch_size).cpu().detach().numpy()

        for i in range(int(num_ims/args.batch_size)):
            gen_out = np.concatenate((gen_out, gen(args.batch_size).cpu().detach().numpy()), 0)

        gen_out = gen_out[:num_ims]

        # Clamp coordinates so every node lands inside the raster.
        gen_out[gen_out > 0.47] = 0.47
        gen_out[gen_out < -0.5] = -0.5

        gen_out = gen_out*[im_px, im_px, 1] + [(im_px+node_r)/2, (im_px+node_r)/2, 0.55]

        for i in range(1, num_ims+1):
            fig.add_subplot(10, 10, i)
            im_disp = draw_graph(gen_out[i-1], node_r, im_px)
            plt.imshow(im_disp, cmap=cm.gray_r, interpolation='nearest')
            plt.axis('off')

        plt.savefig(args.figs_path + args.name + "/" + str(epoch) + ".png")
        plt.close()

        plt.figure()
        plt.plot(dlosses, label='Discriminitive loss')
        plt.plot(glosses, label='Generative loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(args.losses_path + args.name + "/" + str(epoch) + ".png")
        plt.close()

        print("saved figs")

    def save_models(name, epoch):
        # Whole-module checkpoints (not state_dicts), matching the load path above.
        torch.save(G, args.model_path + args.name + "/G_" + str(epoch) + ".pt")
        torch.save(D, args.model_path + args.name + "/D_" + str(epoch) + ".pt")

    # from https://github.com/EmilienDupont/wgan-gp
    def gradient_penalty(real_data, generated_data):
        """WGAN-GP penalty: pushes the critic's gradient norm toward 1 on interpolates."""
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1, 1)
        alpha = alpha.expand_as(real_data).to(device)
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True).to(device)

        del alpha
        torch.cuda.empty_cache()

        # Calculate probability of interpolated examples
        prob_interpolated = D(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).to(device), create_graph=True, retain_graph=True, allow_unused=True)[0].to(device)

        gradients = gradients.contiguous()

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return args.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def train_D(data):
        """One discriminator/critic step on a real batch; returns the scalar loss."""
        D.train()
        D_optimizer.zero_grad()

        run_batch_size = data.shape[0] if not GCNN else data.y.shape[0]

        if(not WGAN):
            Y_real = torch.ones(run_batch_size, 1).to(device)
            Y_fake = torch.zeros(run_batch_size, 1).to(device)

        try:
            D_real_output = D(data)
            gen_ims = gen(run_batch_size)

            tg_gen_ims = tg_transform(gen_ims)

            use_gen_ims = tg_gen_ims if GCNN else gen_ims

            D_fake_output = D(use_gen_ims)

            if(WGAN):
                D_loss = D_fake_output.mean() - D_real_output.mean() + gradient_penalty(data, use_gen_ims)
            else:
                D_real_loss = criterion(D_real_output, Y_real)
                D_fake_loss = criterion(D_fake_output, Y_fake)

                D_loss = D_real_loss + D_fake_loss

            D_loss.backward()
            D_optimizer.step()

        except Exception:
            # Was a bare ``except:`` -- keep the diagnostic dump but no longer
            # swallow KeyboardInterrupt/SystemExit.
            print("Generated Images")
            print(gen_ims)

            print("Transformed Images")
            print(tg_gen_ims)

            print("Discriminator Output")
            print(D_fake_output)

            torch.save(gen_ims, args.err_path + args.name + "_gen_ims.pt")
            torch.save(tg_gen_ims.x, args.err_path + args.name + "_x.pt")
            torch.save(tg_gen_ims.pos, args.err_path + args.name + "_pos.pt")
            torch.save(tg_gen_ims.edge_index, args.err_path + args.name + "_edge_index.pt")
            # NOTE(review): returning None makes the caller's
            # ``D_loss += train_D(...)`` raise TypeError -- confirm whether a
            # hard stop after dumping diagnostics is the intent.
            return

        return D_loss.item()

    def train_G():
        """One generator step (fool the discriminator); returns the scalar loss."""
        G.train()
        G_optimizer.zero_grad()

        if(not WGAN):
            Y_real = torch.ones(args.batch_size, 1).to(device)

        gen_ims = gen(args.batch_size)

        if(GCNN):
            gen_ims = tg_transform(gen_ims)

        D_fake_output = D(gen_ims)

        if(WGAN):
            G_loss = -D_fake_output.mean()
        else:
            G_loss = criterion(D_fake_output, Y_real)

        G_loss.backward()
        G_optimizer.step()

        return G_loss.item()

    D_losses = []
    G_losses = []

    # save_models(name, 0)

    # save_sample_outputs(args.name, 0, D_losses, G_losses)

    # @profile
    def train():
        """Main loop: ``num_critic`` D steps per G step; save figs/models every 5 epochs."""
        for i in range(args.start_epoch, args.num_epochs):
            print("Epoch %d %s" % ((i+1), args.name))
            D_loss = 0
            G_loss = 0
            loader = tgX_loaded if GCNN else X_loaded
            for batch_ndx, data in tqdm(enumerate(loader), total=len(loader)):
                if(batch_ndx > 0 and batch_ndx % (args.num_critic+1) == 0):
                    G_loss += train_G()
                else:
                    D_loss += train_D(data.to(device)) if GCNN else train_D(data[0].to(device))

            D_losses.append(D_loss/len(X_loaded)/2)
            G_losses.append(G_loss/len(X_loaded))

            if((i+1) % 5 == 0):
                save_sample_outputs(args.name, i+1, D_losses, G_losses)

            if((i+1) % 5 == 0):
                save_models(args.name, i+1)

    train()
Пример #24
0
    batches_per_epoch = len(train_loader)

    ########### Creare criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    lr_fn = learning_rate_with_decay(args.batch_size,
                                     batch_denom=128,
                                     batches_per_epoch=batches_per_epoch,
                                     boundary_epochs=[60, 100, 140],
                                     decay_rates=[1, 0.1, 0.01, 0.001],
                                     lr0=args.lr)

    optimizer = optim.RMSprop([
        {
            "params": model.parameters(),
            'lr': args.lr
        },
    ],
                              lr=args.lr)

    ########### Train the model
    nsolvers = len(train_solvers)
    best_acc = [0] * nsolvers

    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs_nn * batches_per_epoch):
Пример #25
0
def main():
    """Train an actor-critic agent with A2C, PPO or ACKTR (per ``args.algo``).

    Sets up (possibly vectorized/normalized) environments, builds a CNN or
    MLP policy, collects rollouts, and performs policy-gradient updates,
    periodically checkpointing the model and logging/plotting progress.

    NOTE(review): relies on module-level objects defined elsewhere in this
    file (args, num_updates, make_env, SubprocVecEnv, DummyVecEnv,
    VecNormalize, CNNPolicy, MLPPolicy, KFACOptimizer, RolloutStorage,
    visdom_plot, Variable, np, torch, ...). Uses the legacy
    ``Variable(..., volatile=True)`` API, so it targets pre-0.4 PyTorch.
    """
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    os.environ['OMP_NUM_THREADS'] = '1'

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = [
        make_env(args.env_name, args.seed, i, args.log_dir)
        for i in range(args.num_processes)
    ]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Observation normalization only for low-dimensional (non-image) inputs.
    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    # Stack args.num_stack consecutive frames along the channel dimension.
    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    if len(envs.observation_space.shape) == 3:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space,
                                 args.recurrent_policy,
                                 args.autoencoder_uncertainty)
    else:
        assert not args.recurrent_policy, \
            "Recurrent policy is not implemented for the MLP controller"
        actor_critic = MLPPolicy(obs_shape[0], envs.action_space)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        optimizer = optim.RMSprop(actor_critic.parameters(),
                                  args.lr,
                                  eps=args.eps,
                                  alpha=args.alpha)
    elif args.algo == 'ppo':
        optimizer = optim.Adam(actor_critic.parameters(),
                               args.lr,
                               eps=args.eps)
    elif args.algo == 'acktr':
        optimizer = KFACOptimizer(actor_critic)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    # an observation carries an consecutive frames stacked to make an input
    def update_current_obs(obs):
        # Shift the frame stack left and append the newest frame.
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        # ---- Rollout collection -------------------------------------------
        for step in range(args.num_steps):
            # Sample actions
            value, action, action_log_prob, states = actor_critic.act(
                Variable(rollouts.observations[step], volatile=True),
                Variable(rollouts.states[step], volatile=True),
                Variable(rollouts.masks[step], volatile=True))
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Obser reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(step, current_obs, states.data, action.data,
                            action_log_prob.data, value.data, reward, masks)

        # Bootstrap value for the final state of the rollout.
        next_value = actor_critic(
            Variable(rollouts.observations[-1], volatile=True),
            Variable(rollouts.states[-1], volatile=True),
            Variable(rollouts.masks[-1], volatile=True))[0].data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        # ---- Policy update ------------------------------------------------
        if args.algo in ['a2c', 'acktr']:

            if args.autoencoder_uncertainty:
                values, action_log_probs, dist_entropy, states, u_loss = actor_critic.evaluate_actions(
                    Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                    Variable(rollouts.states[0].view(-1,
                                                     actor_critic.state_size)),
                    Variable(rollouts.masks[:-1].view(-1, 1)),
                    Variable(rollouts.actions.view(-1, action_shape)),
                    next_state_target=Variable(rollouts.observations[1:].view(
                        -1, *obs_shape)),
                    action_space=envs.action_space.n)
            else:
                values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                    Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                    Variable(rollouts.states[0].view(-1,
                                                     actor_critic.state_size)),
                    Variable(rollouts.masks[:-1].view(-1, 1)),
                    Variable(rollouts.actions.view(-1, action_shape)))

            values = values.view(args.num_steps, args.num_processes, 1)
            action_log_probs = action_log_probs.view(args.num_steps,
                                                     args.num_processes, 1)

            # Detach advantages for the policy loss so only the critic loss
            # trains the value head.
            advantages = Variable(rollouts.returns[:-1]) - values
            value_loss = advantages.pow(2).mean()

            action_loss = -(Variable(advantages.data) *
                            action_log_probs).mean()

            if args.algo == 'acktr' and optimizer.steps % optimizer.Ts == 0:
                # Sampled fisher, see Martens 2014
                actor_critic.zero_grad()
                pg_fisher_loss = -action_log_probs.mean()

                value_noise = Variable(torch.randn(values.size()))
                if args.cuda:
                    value_noise = value_noise.cuda()

                sample_values = values + value_noise
                vf_fisher_loss = -(values -
                                   Variable(sample_values.data)).pow(2).mean()

                fisher_loss = pg_fisher_loss + vf_fisher_loss
                optimizer.acc_stats = True
                fisher_loss.backward(retain_graph=True)
                optimizer.acc_stats = False

            optimizer.zero_grad()
            total_loss = (value_loss * args.value_loss_coef + action_loss -
                          dist_entropy * args.entropy_coef)
            if args.autoencoder_uncertainty:
                # BUGFIX: u_loss is only bound in the autoencoder branch above;
                # the original added u_loss[0] unconditionally, raising
                # NameError whenever the flag was off.
                total_loss = total_loss + u_loss[0]
            total_loss.backward()

            if args.algo == 'a2c':
                nn.utils.clip_grad_norm(actor_critic.parameters(),
                                        args.max_grad_norm)

            optimizer.step()
        elif args.algo == 'ppo':
            advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-5)

            for e in range(args.ppo_epoch):
                if args.recurrent_policy:
                    data_generator = rollouts.recurrent_generator(
                        advantages, args.num_mini_batch)
                else:
                    data_generator = rollouts.feed_forward_generator(
                        advantages, args.num_mini_batch)

                for sample in data_generator:
                    observations_batch, states_batch, actions_batch, \
                       return_batch, masks_batch, old_action_log_probs_batch, \
                            adv_targ = sample

                    # Reshape to do in a single forward pass for all steps
                    assert args.autoencoder_uncertainty == False
                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                        Variable(observations_batch), Variable(states_batch),
                        Variable(masks_batch), Variable(actions_batch))

                    adv_targ = Variable(adv_targ)
                    ratio = torch.exp(action_log_probs -
                                      Variable(old_action_log_probs_batch))
                    surr1 = ratio * adv_targ
                    surr2 = torch.clamp(ratio, 1.0 - args.clip_param,
                                        1.0 + args.clip_param) * adv_targ
                    action_loss = -torch.min(
                        surr1,
                        surr2).mean()  # PPO's pessimistic surrogate (L^CLIP)

                    value_loss = (Variable(return_batch) -
                                  values).pow(2).mean()

                    optimizer.zero_grad()
                    (value_loss + action_loss -
                     dist_entropy * args.entropy_coef).backward()
                    nn.utils.clip_grad_norm(actor_critic.parameters(),
                                            args.max_grad_norm)
                    optimizer.step()

        rollouts.after_update()

        # ---- Checkpointing ------------------------------------------------
        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs, 'ob_rms') and envs.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        # ---- Logging / visualization -------------------------------------
        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(),
                        dist_entropy.data[0], value_loss.data[0],
                        action_loss.data[0]))
            if args.autoencoder_uncertainty:
                print(", autoencoder loss {:.5f}".format(u_loss.data[0]))
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo)
            except IOError:
                pass
Пример #26
0
def init_optimizer(network):
    """Create the optimizer selected by the module-level ``opt.optimizer`` flag.

    ``'adam'`` gives Adam(lr=2e-4, betas=(0.5, 0.999)); any other value gives
    RMSprop(lr=5e-5), the classic WGAN critic setting.
    """
    params = network.parameters()
    if opt.optimizer == 'adam':
        return optim.Adam(params, lr=0.0002, betas=(0.5, 0.999))
    return optim.RMSprop(params, lr=0.00005)
Пример #27
0
def analyze_grads_over_time(config, pretrain_model=False):
    """Compare gradient magnitudes across time steps for a vanilla RNN vs an LSTM.

    Runs one backward pass per model on a single batch of the palindrome task,
    collects the L2 norm of the gradient at each hidden state, and saves a plot
    of both curves under part1/figures/.

    NOTE: statement order matters for reproducibility -- the manual seeds are
    set before model construction and data sampling, so the RNG stream is
    consumed in a fixed order. Depends on project-level VanillaRNN, LSTM,
    PalindromeDataset, DataLoader and train().
    """
    device = torch.device(config.device)
    # Long sequence to make vanishing gradients visible.
    config.input_length = 150

    # Fix RNGs so both models see identical data and initialization draws.
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    # One list of per-timestep gradient norms per model (RNN first, then LSTM).
    total_norms = []

    for m in ["RNN", "LSTM"]:

        # pretrain model
        if pretrain_model:
            # NOTE(review): train(config) presumably returns a trained model
            # regardless of ``m`` -- confirm it honors the model type.
            model = train(config)
        else:
            # Initialize params for models
            seq_length = config.input_length
            input_dim = config.input_dim
            num_hidden = config.num_hidden
            num_classes = config.num_classes

            # Initialize the model that we are going to use
            if m == 'RNN':
                model = VanillaRNN(seq_length, input_dim, num_hidden,
                                   num_classes, device)
            else:
                model = LSTM(seq_length, input_dim, num_hidden, num_classes,
                             device)

            model.to(device)

        # Initialize the dataset and data loader (note the +1)
        dataset = PalindromeDataset(config.input_length + 1)
        # data_loader = DataLoader(dataset, batch_size=1, num_workers=1)
        data_loader = DataLoader(dataset,
                                 batch_size=config.batch_size,
                                 num_workers=1)

        # Setup the loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)

        # Get single batch from dataloader
        batch_inputs, batch_targets, = next(iter(data_loader))

        # convert to one-hot
        batch_inputs = torch.scatter(
            torch.zeros(*batch_inputs.size(), config.num_classes), 2,
            batch_inputs[..., None].to(torch.int64), 1).to(device)
        batch_targets = batch_targets.to(device)

        # analyze_hs_gradients is expected to retain per-step hidden states
        # (model.h_states) with grad enabled -- defined in the project model code.
        train_output = model.analyze_hs_gradients(batch_inputs)
        loss = criterion(train_output, batch_targets)

        # Single backward pass; optimizer.step() is deliberately never called.
        optimizer.zero_grad()
        loss.backward()

        # Walk the stored hidden states from last to first time step.
        gradient_norms = []
        for i, (t, h) in enumerate(reversed(model.h_states)):
            _grad = h.grad  # (batch_size x hidden_dim)
            average_grads = torch.mean(
                _grad, dim=0
            )  # Calculate average gradient to get more stable estimate
            grad_l2_norm = average_grads.norm(2).item()
            gradient_norms.append(grad_l2_norm)

        print(len(gradient_norms))
        total_norms.append(gradient_norms)

    time_steps = np.arange(150)
    print(time_steps)

    # Plot both curves on one figure and save to disk.
    fig = plt.figure(figsize=(15, 10), dpi=150)
    # fig.suptitle('L2-norm of Gradients across Time Steps (LSTM $b_f = 2$)', fontsize=32)
    fig.suptitle('L2-norm of Gradients across Time Steps', fontsize=36)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(total_norms[0], linewidth=2, color="tomato", label="RNN")
    ax.plot(total_norms[1], linewidth=2, color="darkblue", label="LSTM")
    ax.tick_params(labelsize=16)
    ax.set_xticks(time_steps[::10])
    ax.set_xticklabels(time_steps[::10])

    ax.set_xlabel('Backpropagation Step', fontsize=24)
    ax.set_ylabel('Gradient Norm (L2)', fontsize=24)
    ax.legend(prop={'size': 16})

    if not os.path.exists('part1/figures/'):
        os.makedirs('part1/figures/')

    plt.savefig("part1/figures/Analyze_gradients_pt_{}.png".format(
        str(pretrain_model)))
    # plt.savefig("part1/figures/Analyze_gradients_pt_{}_bias_2.png".format(str(pretrain_model)))
    plt.show()
Пример #28
0
def train(args, t_args, device):
    """Train a FuN-style hierarchical RL agent (manager + worker) with actor-critic.

    Runs episodes indefinitely (``for ep_num in count()``): collects one full
    episode, computes manager/worker returns and advantages, and performs one
    combined policy/value update per episode. TensorBoard scalars are written
    to hrl_train_logs/ and a checkpoint is saved every 1000 episodes.

    args     -- dict of hyperparameters ('d', 'len_hist', 'eps', 'k', 'gamma_m',
                'gamma_w', 'alpha', 'entropy_coef', 'num_worker', ...)
    t_args   -- task/environment settings (grid_size, sparse, seed)
    device   -- torch device for all tensors

    NOTE: depends on project-level make_env, FuNet, take_action and
    compute_returns; never returns on its own.
    """
    env, input_shape = make_env(grid_size=t_args.grid_size,
                                sparse=t_args.sparse,
                                random_seed=t_args.seed)
    env = env[0]
    num_actions = env.action_space.n
    steps = 0

    # Fresh log dir per run: wipe any previous TensorBoard events.
    log_dir = 'hrl_train_logs/'

    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    writer = SummaryWriter(log_dir=log_dir)

    f_net = FuNet(input_shape, args['d'], args['len_hist'], args['eps'],
                  args['k'], num_actions, args['num_worker'], device)

    optimizer = optim.RMSprop(f_net.parameters(), lr=1e-3)
    # Rolling goal / manager-state histories carried across episodes.
    goal_history, s_Mt_hist, ep_binary = f_net.agent_model_init()

    for ep_num in count():
        state = env.reset()
        # NOTE(review): copies the dirt layout for reward bookkeeping --
        # specific to this gridworld env; confirm against its implementation.
        env.rw_dirts = env.dirts

        # Detach goals so gradients don't flow across episode boundaries.
        goal_history = [g.detach() for g in goal_history]

        # Per-episode transition buffers.
        mini_db = {
            'log_probs': [],
            'values_manager': [],
            'values_worker': [],
            'rewards': [],
            'intrinsic_rewards': [],
            'goal_errors': [],
            'masks': [],
            'entropy': 0
        }

        # ---- Collect one full episode ----
        for i in count():

            state = torch.from_numpy(state.reshape(1, -1)).float().to(device)
            action_probs, v_Mt, v_Wt, goal_history, s_Mt_hist = f_net(
                state, goal_history, s_Mt_hist)
            a_t, log_p, etr = take_action(action_probs)
            next_state, reward, done, ep_info = env.step(a_t.item())

            # 1.0 while the episode continues, 0.0 on the terminal step.
            ep = torch.FloatTensor([1.0 - done]).unsqueeze(-1).to(device)

            # Fixed-length sliding window of continuation flags.
            ep_binary.pop(0)
            ep_binary.append(ep)

            mini_db['entropy'] += etr.mean()

            mini_db['log_probs'].append(log_p.unsqueeze(-1))
            mini_db['values_manager'].append(v_Mt)
            mini_db['values_worker'].append(v_Wt)
            mini_db['rewards'].append(
                torch.tensor([[reward]], dtype=torch.float, device=device))
            mini_db['masks'].append(
                torch.tensor([[1 - done]], dtype=torch.float, device=device))
            mini_db['intrinsic_rewards'].append(
                f_net.int_reward(goal_history, s_Mt_hist,
                                 ep_binary).unsqueeze(-1))
            mini_db['goal_errors'].append(
                f_net.del_g_theta(goal_history, s_Mt_hist, ep_binary))

            state = next_state
            steps += 1

            if done:
                writer.add_scalars('episode/reward',
                                   {'reward': ep_info['ep_rewards']}, ep_num)
                writer.add_scalars('episode/length',
                                   {'length': ep_info['ep_len']}, ep_num)
                break

        # ---- Bootstrap values from the state after the episode end ----
        next_state = torch.from_numpy(next_state.reshape(
            1, -1)).float().to(device)
        _, v_Mtp1, v_Wtp1, _, _ = f_net(next_state, goal_history, s_Mt_hist)

        # Discounted returns: manager and worker use separate gammas.
        ret_m = compute_returns(v_Mtp1, mini_db['rewards'], mini_db['masks'],
                                args['gamma_m'])
        ret_w = compute_returns(v_Wtp1, mini_db['rewards'], mini_db['masks'],
                                args['gamma_w'])

        log_probs = torch.cat(mini_db['log_probs'])
        ret_m = torch.cat(ret_m).detach()
        ret_w = torch.cat(ret_w).detach()
        intrinsic_rewards = torch.cat(mini_db['intrinsic_rewards'])
        goal_errors = torch.cat(mini_db['goal_errors'])

        value_m = torch.cat(mini_db['values_manager'])
        value_w = torch.cat(mini_db['values_worker'])

        # Worker advantage mixes external return with scaled intrinsic reward.
        advantage_m = ret_m - value_m
        advantage_w = (ret_w + args['alpha'] * intrinsic_rewards) - value_w

        # Policy gradients: manager trains via goal-direction errors, worker
        # via its action log-probs; advantages are detached from the critics.
        loss_manager = -1 * (goal_errors * advantage_m.detach()).mean()
        loss_worker = -1 * (log_probs * advantage_w.detach()).mean()

        value_m_loss = 0.5 * advantage_m.pow(2).mean()
        value_w_loss = 0.5 * advantage_w.pow(2).mean()

        # Combined objective with entropy bonus for exploration.
        loss = loss_worker + loss_manager + value_w_loss + value_m_loss - (
            args['entropy_coef'] * mini_db['entropy'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalars('loss/total_loss', {'total_loss': loss}, ep_num)
        writer.add_scalars('loss/manager_loss', {'manager_loss': loss_manager},
                           ep_num)
        writer.add_scalars('loss/worker_loss', {'worker_loss': loss_worker},
                           ep_num)
        writer.add_scalars('loss/worker_value_fn_loss',
                           {'worker_value_func_loss': value_w_loss}, ep_num)
        writer.add_scalars('loss/manager_value_fn_loss',
                           {'man_value_func_loss': value_m_loss}, ep_num)

        # NOTE(review): assumes the 'saved_model/' directory already exists --
        # torch.save does not create it.
        if ep_num % 1000 == 0:
            torch.save(
                {
                    'model': f_net.state_dict(),
                    'args': args,
                    'goal': goal_history,
                    'man_state': s_Mt_hist
                }, 'saved_model/fnet_ckpt.pt')
Пример #29
0
# DQN hyperparameters.
BATCH_SIZE = 128                 # replay-batch size per optimization step
GAMMA = 0.999                    # discount factor
EPS_START = 0.9                  # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05                   # final epsilon
EPS_DECAY = 200                  # exponential decay rate (in steps) for epsilon
UPDATE_TARGET_Q_FREQ = 3         # how often to sync the target network
TRAIN_FREQ = 5                   # environment steps per training step
num_episodes = 20

# Online network and its target copy (q_ast is synced from model periodically).
model = DQN()
q_ast = deepcopy(model)

if use_cuda:
    model.cuda()

# RMSprop with library-default hyperparameters; replay buffer holds 10k transitions.
optimizer = optim.RMSprop(model.parameters())
memory = ReplayMemory(10000)

# Global step counter driving the epsilon schedule in select_action.
steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        return model(Variable(
            state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
    else:
Example #30
0
def train(args):
    """Train a PIXOR detector on KITTI.

    Depending on flags in ``args``, trains either plain PIXOR, a
    camera/BEV fusion PIXOR, or (``args.e2e``) PIXOR jointly with a PSMNet
    stereo-depth network whose predicted depth is converted to pseudo-LiDAR
    voxels on the fly.  Checkpoints are written to
    ``args.saverootpath/args.run_name`` every ``args.save_every`` epochs and
    evaluation runs every ``args.eval_every_epoch`` epochs from
    ``args.start_eval`` onward.

    Args:
        args: parsed argument namespace (dataset paths, lr schedules,
            model/optimizer flags).  See the script's argument parser.

    Raises:
        AssertionError: if no CUDA device is available.
        NotImplementedError: if ``args.opt_method`` is not ``'RMSprop'``.
    """
    use_gpu = torch.cuda.is_available()
    num_gpu = list(range(torch.cuda.device_count()))
    # NOTE: assert is stripped under `python -O`; kept for interface parity.
    assert use_gpu, "Please use gpus."

    logger = get_logger(name=args.shortname)
    display_args(args, logger)

    # Create the directory checkpoints and predictions are saved into.
    args.saverootpath = osp.abspath(args.saverootpath)
    savepath = osp.join(args.saverootpath, args.run_name)
    if not osp.exists(savepath):
        os.makedirs(savepath)

    train_file = os.path.join(args.image_sets,
                              "{}.txt".format(args.train_dataset))
    # 36 BEV input channels with the reflectance channel, 35 without it.
    n_features = 35 if args.no_reflex else 36
    if args.pixor_fusion:
        if args.e2e:
            # End-to-end: the dataset must yield stereo pairs and depth maps
            # so the depth network can be supervised jointly.
            train_data = KittiDataset_Fusion_stereo(
                txt_file=train_file,
                flip_rate=args.flip_rate,
                lidar_dir=args.eval_lidar_dir,
                label_dir=args.eval_label_dir,
                calib_dir=args.eval_calib_dir,
                image_dir=args.eval_image_dir,
                root_dir=args.root_dir,
                only_feature=args.no_cal_loss,
                split=args.split,
                image_downscale=args.image_downscale,
                crop_height=args.crop_height,
                random_shift_scale=args.random_shift_scale)
        else:
            train_data = KittiDataset_Fusion(
                txt_file=train_file,
                flip_rate=args.flip_rate,
                lidar_dir=args.train_lidar_dir,
                label_dir=args.train_label_dir,
                calib_dir=args.train_calib_dir,
                n_features=n_features,
                random_shift_scale=args.random_shift_scale,
                root_dir=args.root_dir,
                image_downscale=args.image_downscale)

    else:
        train_data = KittiDataset(txt_file=train_file,
                                  flip_rate=args.flip_rate,
                                  lidar_dir=args.train_lidar_dir,
                                  label_dir=args.train_label_dir,
                                  calib_dir=args.train_calib_dir,
                                  image_dir=args.train_image_dir,
                                  n_features=n_features,
                                  random_shift_scale=args.random_shift_scale,
                                  root_dir=args.root_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=8)

    eval_data, eval_loader = get_eval_dataset(args)

    if args.pixor_fusion:
        pixor = PixorNet_Fusion(n_features,
                                groupnorm=args.groupnorm,
                                resnet_type=args.resnet_type,
                                image_downscale=args.image_downscale,
                                resnet_chls=args.resnet_chls)
    else:
        pixor = PixorNet(n_features, groupnorm=args.groupnorm)

    pixor = pixor.cuda()
    pixor = nn.DataParallel(pixor, device_ids=num_gpu)

    # Per-element losses; the reduction/masking happens inside compute_loss.
    class_criterion = nn.BCELoss(reduction='none')
    reg_criterion = nn.SmoothL1Loss(reduction='none')

    if args.opt_method == 'RMSprop':
        optimizer = optim.RMSprop(pixor.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    else:
        raise NotImplementedError()

    depth_model = PSMNet(maxdepth=80, maxdisp=192, down=args.depth_down)
    depth_model = nn.DataParallel(depth_model).cuda()
    # torch.backends.cudnn.benchmark = True
    depth_optimizer = optim.Adam(depth_model.parameters(),
                                 lr=args.depth_lr,
                                 betas=(0.9, 0.999))
    # Fixed 700x800x35 grid used to voxelize predicted depth into BEV
    # features on the e2e path.
    grid_3D_extended = get_3D_global_grid_extended(700, 800, 35).cuda().float()

    if args.depth_pretrain:
        if os.path.isfile(args.depth_pretrain):
            logger.info("=> loading depth pretrain '{}'".format(
                args.depth_pretrain))
            checkpoint = torch.load(args.depth_pretrain)
            depth_model.load_state_dict(checkpoint['state_dict'])
            depth_optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            logger.info('[Attention]: Do not find checkpoint {}'.format(
                args.depth_pretrain))

    depth_scheduler = MultiStepLR(depth_optimizer,
                                  milestones=args.depth_lr_stepsize,
                                  gamma=args.depth_lr_gamma)

    if args.pixor_pretrain:
        if os.path.isfile(args.pixor_pretrain):
            logger.info("=> loading depth pretrain '{}'".format(
                args.pixor_pretrain))
            checkpoint = torch.load(args.pixor_pretrain)
            pixor.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Boost the restored lr by 10x -- presumably to fine-tune on top
            # of the pretrained detector; confirm against experiment config.
            optimizer.param_groups[0]['lr'] *= 10

        else:
            logger.info('[Attention]: Do not find checkpoint {}'.format(
                args.pixor_pretrain))

    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.lr_milestones,
                                         gamma=args.gamma)

    if args.resume:
        logger.info("Resuming...")
        checkpoint_path = osp.join(savepath, args.checkpoint)
        if os.path.isfile(checkpoint_path):
            logger.info("Loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            pixor.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            depth_model.load_state_dict(checkpoint['depth_state_dict'])
            depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
            depth_scheduler.load_state_dict(checkpoint['depth_scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            logger.info(
                "Resumed successfully from epoch {}.".format(start_epoch))
        else:
            logger.warning("Model {} not found. "
                           "Train from scratch".format(checkpoint_path))
            start_epoch = 0
    else:
        start_epoch = 0

    class_criterion = class_criterion.cuda()
    reg_criterion = reg_criterion.cuda()

    processes = []  # async evaluation subprocesses, reaped after training
    last_eval_epoches = []  # (epoch, iou) pairs whose official results are pending
    for epoch in range(start_epoch, args.epochs):
        pixor.train()
        depth_model.train()
        # NOTE(review): schedulers are stepped at epoch START (pre-PyTorch-1.1
        # ordering); kept as-is to preserve the original lr schedule.
        scheduler.step()
        depth_scheduler.step()
        ts = time.time()
        logger.info("Start epoch {}, depth lr {:.6f} pixor lr {:.7f}".format(
            epoch, depth_optimizer.param_groups[0]['lr'],
            optimizer.param_groups[0]['lr']))

        avg_class_loss = AverageMeter()
        avg_reg_loss = AverageMeter()
        avg_total_loss = AverageMeter()

        train_metric = utils_func.Metric()

        for iteration, batch in enumerate(train_loader):

            if args.pixor_fusion:
                if not args.e2e:
                    inputs = batch['X'].cuda()
                else:
                    # Stereo inputs for the depth network plus the metadata
                    # needed to map its prediction back into the BEV frame
                    # (augmentation shifts, original shape, flip flag).
                    imgL = batch['imgL'].cuda()
                    imgR = batch['imgR'].cuda()
                    f = batch['f']
                    depth_map = batch['depth_map'].cuda()
                    idxx = batch['idx']
                    h_shift = batch['h_shift']
                    ori_shape = batch['ori_shape']
                    a_shift = batch['a_shift']
                    flip = batch['flip']
                images = batch['image'].cuda()
                img_index = batch['img_index'].cuda()
                bev_index = batch['bev_index'].cuda()
            else:
                inputs = batch['X'].cuda()
            class_labels = batch['cl'].cuda()
            reg_labels = batch['rl'].cuda()

            if args.pixor_fusion:
                if not args.e2e:
                    class_outs, reg_outs = pixor(inputs, images, img_index,
                                                 bev_index)
                else:
                    depth_loss, depth_map = forward_depth_model(
                        imgL, imgR, depth_map, f, train_metric, depth_model)
                    # Convert each predicted depth map to a pseudo-LiDAR
                    # point cloud, undo the augmentation shifts, and voxelize
                    # it into the detector's BEV input.
                    inputs = []
                    for i in range(depth_map.shape[0]):
                        calib = utils_func.torchCalib(
                            train_data.dataset.get_calibration(idxx[i]),
                            h_shift[i])
                        H, W = ori_shape[0][i], ori_shape[1][i]
                        depth = depth_map[i][-H:, :W]
                        ptc = depth_to_pcl(calib, depth, max_high=1.)
                        ptc = calib.lidar_to_rect(ptc[:, 0:3])

                        # Undo the rotational augmentation, if any.
                        if torch.abs(a_shift[i]).item() > 1e-6:
                            roty = utils_func.roty_pth(a_shift[i]).cuda()
                            ptc = torch.mm(ptc, roty.t())
                        voxel = gen_feature_diffused_tensor(
                            ptc,
                            700,
                            800,
                            grid_3D_extended,
                            diffused=args.diffused)

                        # Mirror the voxel grid for horizontally-flipped samples.
                        if flip[i] > 0:
                            voxel = torch.flip(voxel, [2])

                        inputs.append(voxel)
                    inputs = torch.stack(inputs)
                    class_outs, reg_outs = pixor(inputs, images, img_index,
                                                 bev_index)
            else:
                class_outs, reg_outs = pixor(inputs)
            class_outs = class_outs.squeeze(1)
            class_loss, reg_loss, loss = \
                compute_loss(epoch, class_outs, reg_outs,
                    class_labels, reg_labels, class_criterion,
                    reg_criterion, args)
            avg_class_loss.update(class_loss.item())
            # compute_loss may return reg_loss as a plain int (e.g. 0) when
            # there are no positive anchors.
            avg_reg_loss.update(reg_loss.item() \
                if not isinstance(reg_loss, int) else reg_loss)
            avg_total_loss.update(loss.item())

            optimizer.zero_grad()
            depth_optimizer.zero_grad()
            # BUGFIX: depth_loss only exists on the e2e path; the original
            # added it unconditionally, raising NameError on the first batch
            # of any non-e2e run.  Joint objective: depth supervision plus
            # down-weighted detection loss.
            if args.pixor_fusion and args.e2e:
                loss = depth_loss + 0.1 * loss
            loss.backward()
            optimizer.step()
            depth_optimizer.step()

            if not isinstance(reg_loss, int):
                reg_loss = reg_loss.item()

            if iteration % args.logevery == 0:
                logger.info("epoch {:d}, iter {:d}, class_loss: {:.5f},"
                            " reg_loss: {:.5f}, loss: {:.5f}".format(
                                epoch, iteration, avg_class_loss.avg,
                                avg_reg_loss.avg, avg_total_loss.avg))

                logger.info(train_metric.print(epoch, iteration))

        logger.info("Finish epoch {}, time elapsed {:.3f} s".format(
            epoch,
            time.time() - ts))

        if epoch % args.eval_every_epoch == 0 and epoch >= args.start_eval:
            logger.info("Evaluation begins at epoch {}".format(epoch))
            evaluate(eval_data,
                     eval_loader,
                     pixor,
                     depth_model,
                     args.batch_size,
                     gpu=use_gpu,
                     logger=logger,
                     args=args,
                     epoch=epoch,
                     processes=processes,
                     grid_3D_extended=grid_3D_extended)
            if args.run_official_evaluate:
                # Track both IoU thresholds (0.7 and 0.5) for later harvesting.
                last_eval_epoches.append((epoch, 7))
                last_eval_epoches.append((epoch, 5))

        if len(last_eval_epoches) > 0:
            # Harvest results of finished official evaluations; iterate a
            # copy since finished entries are removed in-flight.  `results`
            # is parsed but currently only used to mark the entry done.
            for e, iou in last_eval_epoches[:]:
                predicted_results = osp.join(args.saverootpath, args.run_name,
                                             'predicted_label_{}'.format(e),
                                             'outputs_{:02d}.txt'.format(iou))
                if osp.exists(predicted_results):
                    with open(predicted_results, 'r') as fp:
                        for line in fp.readlines():
                            if line.startswith('car_detection_ground AP'):
                                results = [
                                    float(num)
                                    for num in line.strip('\n').split(' ')[-3:]
                                ]
                                last_eval_epoches.remove((e, iou))

        if epoch % args.save_every == 0:
            saveto = osp.join(savepath, "checkpoint_{}.pth.tar".format(epoch))
            torch.save(
                {
                    'state_dict': pixor.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'depth_state_dict': depth_model.state_dict(),
                    'depth_optimizer': depth_optimizer.state_dict(),
                    'depth_scheduler': depth_scheduler.state_dict(),
                    'epoch': epoch
                }, saveto)
            logger.info("model saved to {}".format(saveto))
            # Keep a stable "latest" symlink alongside the numbered checkpoint.
            symlink_force(saveto, osp.join(savepath, "checkpoint.pth.tar"))

    # Reap any evaluation subprocesses spawned during training.
    for p in processes:
        if p.wait() != 0:
            logger.warning("There was an error")

    # Final sweep for evaluation outputs that finished after the last epoch.
    if len(last_eval_epoches) > 0:
        for e, iou in last_eval_epoches[:]:
            predicted_results = osp.join(args.saverootpath, args.run_name,
                                         'predicted_label_{}'.format(e),
                                         'outputs_{:02d}.txt'.format(iou))
            if osp.exists(predicted_results):
                with open(predicted_results, 'r') as fp:
                    for line in fp.readlines():
                        if line.startswith('car_detection_ground AP'):
                            results = [
                                float(num)
                                for num in line.strip('\n').split(' ')[-3:]
                            ]
                            last_eval_epoches.remove((e, iou))