Training loops

Quick Keras-like training using skorch

import skorch
from skorch import NeuralNetClassifier

# ---

# Data loading
# X,y = ...

# Model creation. The model is a regular PyTorch module.
# model = ...

# ---

# The dataset should be in one big numpy.ndarray: float32 features and int64
# class indices (the target dtype expected by CrossEntropyLoss).
X = X.cpu().numpy().astype(numpy.float32)
# Fixed: the labels loaded above are named `y`; `Y` was never defined.
# `squeeze()` drops singleton dimensions (presumably (N, 1) -> (N,); confirm
# against the data-loading code).
y = y.cpu().squeeze().numpy().astype(numpy.int64)

# Train a classifier
net = NeuralNetClassifier(
    model,
    max_epochs=300,
    lr=0.001,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    callbacks=[
        # Stop when `valid_loss` has not improved by more than `threshold`
        # for `patience` consecutive epochs.
        (
            'early_s',
            skorch.callbacks.EarlyStopping(
                monitor='valid_loss',
                patience=20,
                threshold=0.0001,
            ),
        ),
    ],
    device='cuda',
    # Below are options specific to the model, with `module__` prepended to
    # the option names; skorch forwards them to the module's constructor.
    module__n_classes=n_classes,
    module__dropout_probability=dropout_rate,
)
net.fit(X, y)

# Save the trained model: only the weights of the underlying PyTorch module
# (`net.module_`), not the skorch wrapper.
filename = 'networks_params.pt'
torch.save(net.module_.state_dict(), filename)

GAN training loop

# TODO: check if it works

def _clamp_parameters_(module, bound):
    """Clamp every parameter of `module` in place to [-bound, +bound]."""
    with torch.no_grad():
        for p in module.parameters():
            p.clamp_(-bound, +bound)


def train_conditional_gan_step(G, D, d_criterion, true_example, condition,
    z_dim=128,
    device='cuda',
    batch_size=32,
    n_critic=1,
    clip_w_discriminator=None, clip_w_generator=None,
    clip_gradients_discriminator=None, clip_gradients_generator=None,
    d_optimizer=None, g_optimizer=None, current_step=None):
    """Run one conditional-GAN training step on a single batch.

    The discriminator is updated on every call. The generator is updated only
    when ``current_step % n_critic == 0``.

    Args:
        G: generator, called as ``G(z, condition)``.
        D: discriminator, called as ``D(example, condition)``. Its output is
            compared against targets of shape ``(batch_size, 1)``.
        d_criterion: loss applied to the discriminator output and the
            real/fake targets (e.g. ``torch.nn.BCELoss()``).
        true_example: batch of real examples, already on `device`.
        condition: conditioning input for the batch. This assumes it is
            already encoded as a one-hot vector.
        z_dim: dimensionality of the noise vector fed to the generator.
        device: device on which targets and noise are placed.
        batch_size: number of examples in the batch; it must match the leading
            dimension of `true_example` and `condition`.
        n_critic: the generator is updated once every `n_critic` steps.
        clip_w_discriminator: if not None, discriminator weights are clamped
            to ``[-value, +value]`` after its optimizer step.
        clip_w_generator: same, for the generator.
        clip_gradients_discriminator: if not None, discriminator gradients are
            clipped element-wise to ``[-value, +value]`` before the step.
        clip_gradients_generator: same, for the generator.
        d_optimizer: optimizer of the discriminator. Defaults to the
            module-level ``D_opt`` for backward compatibility.
        g_optimizer: optimizer of the generator. Defaults to the module-level
            ``G_opt`` for backward compatibility.
        current_step: global step counter. Defaults to the module-level
            ``step`` for backward compatibility.

    Returns:
        ``(dloss, gloss)`` as Python floats. ``gloss`` is ``nan`` on steps
        where the generator was not updated.
    """
    # Historically this function read its optimizers and the step counter
    # from globals; keep that as the fallback so existing callers still work.
    if d_optimizer is None:
        d_optimizer = D_opt
    if g_optimizer is None:
        g_optimizer = G_opt
    if current_step is None:
        current_step = step

    # -------------------------------------------
    # True/False targets
    # -------------------------------------------
    real_targets = torch.ones(batch_size, 1, device=device)
    fake_targets = torch.zeros(batch_size, 1, device=device)

    # -------------------------------------------
    # DISCRIMINATOR
    # -------------------------------------------
    real_scores = D(true_example, condition)
    d_real_loss = d_criterion(real_scores, real_targets)

    z = torch.randn(batch_size, z_dim).to(device)
    # detach(): the discriminator update must not backpropagate into G.
    fake_examples = G(z, condition).detach()
    fake_scores = D(fake_examples, condition)
    d_fake_loss = d_criterion(fake_scores, fake_targets)
    D_loss = d_real_loss + d_fake_loss

    D.zero_grad()
    D_loss.backward()

    # Clip gradients of discriminator
    if clip_gradients_discriminator is not None:
        torch.nn.utils.clip_grad_value_(D.parameters(), clip_gradients_discriminator)  # usually: 0.1

    d_optimizer.step()

    # Clip weights of discriminator. This must happen AFTER the optimizer
    # step, otherwise the update can push the weights back out of range.
    if clip_w_discriminator is not None:
        _clamp_parameters_(D, clip_w_discriminator)

    # -------------------------------------------
    # GENERATOR
    # -------------------------------------------
    # NaN marks "generator not updated on this step" while keeping a float
    # return value (the original code raised a NameError here instead).
    gloss = float('nan')
    if current_step % n_critic == 0:
        z = torch.randn(batch_size, z_dim).to(device)
        fake_scores = D(G(z, condition), condition)
        # The generator is trained to make D label its samples as real.
        G_loss = d_criterion(fake_scores, real_targets)

        G.zero_grad()
        # This also accumulates gradients in D; they are discarded by the
        # D.zero_grad() call at the start of the next discriminator update.
        G_loss.backward()

        # Clip gradients of generator (fixed: this used to test the
        # discriminator's clipping flag).
        if clip_gradients_generator is not None:
            torch.nn.utils.clip_grad_value_(G.parameters(), clip_gradients_generator)  # usually: 0.1

        g_optimizer.step()

        # Clip weights of generator, after the step (see discriminator above).
        if clip_w_generator is not None:
            _clamp_parameters_(G, clip_w_generator)

        gloss = G_loss.item()

    dloss = D_loss.item()

    return dloss, gloss
# TODO: update everything below (for now, just copy and paste)


# Hyper-parameters of the GAN run. The whole dict is logged to TensorBoard as
# hparams and saved next to every checkpoint.
config = dict(
    device=device_name,
    # NOTE(review): not read anywhere in the visible code; confirm whether it
    # is meant to be passed as the discriminator weight-clipping value.
    D_params_clip_value=0.01,
    # The generator is updated once every `n_critic` steps.
    n_critic=1,
    batch_size=256,
    # Size of the noise vector given to the generator.
    dim_noise=128,
    # Adam settings, separately for the discriminator (D) and generator (G).
    adam_lr_D=4e-4,
    adam_lr_G=4e-4,
    adam_beta_1_G=0.9,
    adam_beta_2_G=0.999,
    adam_beta_1_D=0.9,
    adam_beta_2_D=0.999,
    training_max_epoch=5000,
)

# create G, D, opts

# start recorders
# epoch loop
#   - pre epoch
#   - batch loop
#     - pre batch
#     - train step
#     - post-batch
#   - post-epoch

# Networks
G = Generator(config['dim_noise']).to(device)
D = Discriminator().to(device)

# Loss and one Adam optimizer per network
d_criterion = torch.nn.BCELoss()
D_opt = torch.optim.Adam(
    D.parameters(),
    lr=config['adam_lr_D'],
    betas=(config['adam_beta_1_D'], config['adam_beta_2_D']),
)
G_opt = torch.optim.Adam(
    G.parameters(),
    lr=config['adam_lr_G'],
    betas=(config['adam_beta_1_G'], config['adam_beta_2_G']),
)

# NOTE(review): this duplicates config['training_max_epoch']; only
# 'max_epoch' is read by the training loop.
config['max_epoch'] = 5000 # need more than 10 epochs for training generator
# Global batch counter; the training step reads it to decide when to update G.
step = 0

start = time.time()

print('[INFO] Started training on {} ...'.format(datetime.datetime.now()))
print('[INFO] Device: {}'.format(device_name))
print('[INFO] Generator params: {:,}'.format(count_parameters(G)))
print('[INFO] Discriminator params: {:,}'.format(count_parameters(D)))

tensorboard_logger = SummaryWriter(flush_secs=10)
tensorboard_logger.add_hparams(hparam_dict=config, metric_dict={})

# Everything produced by this run lives under the TensorBoard log directory.
logdir = Path(tensorboard_logger.get_logdir())
path_images = logdir / 'generated_sequences_images'
path_images.mkdir()
path_videos = logdir / 'generated_sequences_videos'
path_videos.mkdir()
path_params = logdir / 'networks_params'
path_params.mkdir()

# Keep a copy of the network definitions next to the results.
(logdir / 'generator.py').write_text(source_code_G)
(logdir / 'discriminator.py').write_text(source_code_D)
save_model = True

# Printed once (the original printed this line twice).
print('[INFO] Logdir: {}'.format(tensorboard_logger.get_logdir()))

# tensorboard_logger.add_graph(G, torch.randn(1, config['dim_noise']).to(device))
# tensorboard_logger.add_graph(D, x_train[0].to(device))
# How often (in steps) samples are rendered and checkpoints are written.
view_every = 10 # 50 # 10
save_every = 40 # 500 # 50

for epoch in tqdm.tqdm_notebook(range(step, step + config['max_epoch']), desc='Epoch'):

    # NOTE(review): `step` counts batches but is only tested here once per
    # epoch, so this fires only if an epoch happens to start at step 100.
    # Confirm the intended schedule for unfreezing the decoder.
    if step == 100:
        for param in G.decoder.parameters():
            param.requires_grad = True

    for sequences, labels in tqdm.tqdm_notebook(data_loader, desc='Epoch', leave=False):

        x = sequences.to(device)
        y = to_onehot(labels).to(device)

        # One discriminator update (and, every `n_critic` steps, one
        # generator update). The actual batch size is passed because the
        # last batch of an epoch can be smaller than config['batch_size'].
        dloss, gloss = train_conditional_gan_step(G, D, d_criterion, x, y,
            z_dim=config['dim_noise'],
            device=config['device'],
            batch_size=x.shape[0],
            n_critic=config['n_critic'],
            clip_w_discriminator=None,
            clip_w_generator=None,
            clip_gradients_discriminator=None,
            clip_gradients_generator=None)

        tensorboard_logger.add_scalar('loss/loss_D', dloss, step)
        tensorboard_logger.add_scalar('loss/loss_G', gloss, step)

        if step % 50 == 0:
            # Fixed: the loss tensors are local to the training step; only
            # the float values it returns are available here.
            print('Epoch: {}/{} | Step: {} | Time elapsed: {} | D Loss: {}, G Loss: {}'.format(epoch, config['max_epoch'], step, time_since(start), dloss, gloss))

        if step % view_every == 0: # and step != 0:
            # Render four samples per class as an image and a video.
            G.eval()
            # NOTE(review): 'condition_size' is not set in the visible
            # `config` dict; it must be added before this loop runs.
            for c in range(config['condition_size']):
                cond = torch.zeros(4, 1)
                cond[:, 0] = c
                cond = to_onehot(cond)
                cond = cond.to(device)
                generated_sequences = get_sample(G, config['dim_noise'], cond, how_many=4, as_numpy=False)
                figure_filename = str(Path(path_images, 'generated_sequences_class_{:02d}_image_step_{:05d}.jpg'.format(c, step)))
                video_filename = str(Path(path_videos, 'generated_sequences_class_{:02d}_video_step_{:05d}.mp4'.format(c, step)))
                plot_four_examples(generated_sequences, save_as=figure_filename, display=False)
                display_animation(block_to_alternated(generated_sequences[0]), display=False, save=True, filename=video_filename,
                                 class_id=c, step=step)
            # Mean of the samples generated for the last class only.
            print('   ...mean=', generated_sequences.mean())
            G.train()

        if save_model is True and step % save_every == 0 and step != 0:
            # One sub-folder per checkpoint, holding both networks, both
            # optimizers and the config.
            path_params = Path(tensorboard_logger.get_logdir(), 'networks_params', 'step__{:04d}'.format(step))
            path_params.mkdir()

            checkpoint_items = (
                ('D', D.state_dict()),
                ('G', G.state_dict()),
                ('D_opt', D_opt.state_dict()),
                ('G_opt', G_opt.state_dict()),
                ('config', config),
            )
            for tag, payload in checkpoint_items:
                checkpoint_name = 'checkpoint_step_{}_notebook_v{}__{}__{}.pkl'.format(step, version_notebook, model_name, tag)
                torch.save(payload, str(Path(path_params, checkpoint_name)))

            print('[INFO] Saved model in folder {}'.format(path_params))

        step += 1